diff --git a/Cargo.lock b/Cargo.lock index 327dcbe..d03f181 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -121,9 +121,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "base64" @@ -142,9 +142,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.11.1" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +checksum = "84d7ced0ae9557296835c32bf1b1e02b44c746701f898460fb000d7eaa84f00a" [[package]] name = "blake2" @@ -210,9 +210,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.61" +version = "1.2.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" dependencies = [ "find-msvc-tools", "shlex", @@ -384,7 +384,7 @@ dependencies = [ [[package]] name = "defguard_wireguard_rs" -version = "0.9.7" +version = "0.10.0" dependencies = [ "base64", "defguard_boringtun", @@ -398,7 +398,6 @@ dependencies = [ "netlink-packet-utils", "netlink-packet-wireguard", "netlink-sys", - "nix", "regex", "serde", "serde_test", @@ -555,9 +554,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "heck" @@ -593,7 +592,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.17.0", + "hashbrown 0.17.1", "serde", "serde_core", ] @@ -649,9 +648,9 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jiff" -version = "0.2.24" +version = "0.2.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" +checksum = "4603d3033e49e2b0e31229fcab20a5d40089c607d975cd9c80551dc69eed9102" dependencies = [ "jiff-static", "log", @@ -662,9 +661,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.24" +version = "0.2.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" +checksum = "782d32378dddf207193ac91cefb848ad41abb58195c95168e1291227a0832b47" dependencies = [ "proc-macro2", "quote", @@ -716,24 +715,15 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "113b30b4cd05f7c06868fdb2854f66a7b9fece9a48425351cd532e810d74024f" [[package]] name = "memchr" -version = "2.8.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" - -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] +checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8" [[package]] name = "minimal-lexical" @@ -784,10 +774,11 @@ dependencies = [ [[package]] name = "netlink-packet-wireguard" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037892b0e01ce41f30398a47be2051e712a2cf1eed9cb7e5e6a92b05c423255b" +checksum = "cb217ca08c02978c3ea5ef0f013d20c602f2a17a9f00d9e4b1518a3500c2f332" dependencies = [ + "bitflags", "libc", "log", "netlink-packet-core", @@ -807,15 +798,14 @@ dependencies = [ [[package]] name = "nix" -version = "0.31.2" +version = "0.31.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +checksum = "cf20d2fde8ff38632c426f1165ed7436270b44f199fc55284c38276f9db47c3d" dependencies = [ "bitflags", "cfg-if", "cfg_aliases", "libc", - "memoffset", ] [[package]] @@ -1125,9 +1115,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -1165,9 +1155,9 @@ dependencies = [ [[package]] name = "shlex" -version = "1.3.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" [[package]] name = "siphasher" @@ -1189,9 +1179,9 @@ checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -1327,7 +1317,7 @@ version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ - "winnow 1.0.2", + "winnow 1.0.3", ] [[package]] @@ -1395,9 +1385,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.20.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "unicode-ident" @@ -1863,9 +1853,9 @@ dependencies = [ [[package]] name = "winnow" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +checksum = "0592e1c9d151f854e6fd382574c3a0855250e1d9b2f99d9281c6e6391af352f1" [[package]] name = "wireguard-nt" diff --git a/Cargo.toml b/Cargo.toml index a285cf1..9fd902a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "defguard_wireguard_rs" -version = "0.9.7" +version = "0.10.0" edition = "2024" rust-version = "1.87" description = "A unified multi-platform high-level API for managing WireGuard interfaces" @@ -29,7 +29,6 @@ defguard_boringtun = { version = "0.6", default-features = false, features = [ "device", ]} libc = { version = "0.2", default-features = false } -nix = { version = "0.31", features = ["ioctl", "socket"] } [target.'cfg(target_os = "windows")'.dependencies] ipnet = "2.11" @@ -46,7 +45,7 @@ netlink-packet-core = "0.8" netlink-packet-generic = "0.4" netlink-packet-route = "0.30" netlink-packet-utils = "0.6" -netlink-packet-wireguard = "0.3" +netlink-packet-wireguard = "0.4" netlink-sys = "0.8" [target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))'.dependencies] diff --git a/src/bsd/ifconfig.rs b/src/bsd/ifconfig.rs index a28212d..7dbcc3d 100644 --- a/src/bsd/ifconfig.rs +++ b/src/bsd/ifconfig.rs @@ -3,72 +3,65 @@ use std::{ os::fd::AsRawFd, }; -use libc::{IF_NAMESIZE, IFF_UP}; -use nix::{ioctl_readwrite, ioctl_write_ptr, sys::socket::AddressFamily}; +use libc::{AF_INET, AF_INET6, AF_UNIX, IF_NAMESIZE, IFF_UP, c_ulong, ioctl}; use super::{ - IoError, create_socket, + IoError, c_int_to_error, create_socket, + ioctl::{iow, iowr}, sockaddr::{SockAddrIn, SockAddrIn6}, }; // From `netinet6/in6.h`. const ND6_INFINITE_LIFETIME: u32 = u32::MAX; -// SIOCIFDESTROY -ioctl_write_ptr!(destroy_clone_if, b'i', 121, IfReq); +// From `sys/sockio.h`. +const SIOCIFDESTROY: c_ulong = iow::(b'i', 121); -// SIOCIFCREATE2 -// FIXME: not on NetBSD -ioctl_readwrite!(create_clone_if, b'i', 124, IfReq); +// Note: SIOCIFCREATE on NetBSD should work as SIOCIFCREATE2 on FreeBSD. +#[cfg(target_os = "netbsd")] +const SIOCIFCREATE: c_ulong = iowr::(b'i', 122); + +// SIOCIFCREATE2 works as SIOCIFCREATE, but let the caller speficy the interface name. +#[cfg(target_os = "freebsd")] +const SIOCIFCREATE2: c_ulong = iowr::(b'i', 124); +#[cfg(target_os = "macos")] +const SIOCIFCREATE2: c_ulong = iowr::(b'i', 124); -// SIOCGIFMTU #[cfg(any(target_os = "freebsd", target_os = "macos"))] -ioctl_readwrite!(get_if_mtu, b'i', 51, IfMtu); +const SIOCGIFMTU: c_ulong = iowr::(b'i', 51); #[cfg(target_os = "netbsd")] -ioctl_readwrite!(get_if_mtu, b'i', 126, IfMtu); +const SIOCGIFMTU: c_ulong = iowr::(b'i', 126); -// SIOCSIFMTU #[cfg(any(target_os = "freebsd", target_os = "macos"))] -ioctl_write_ptr!(set_if_mtu, b'i', 52, IfMtu); +const SIOCSIFMTU: c_ulong = iow::(b'i', 52); #[cfg(target_os = "netbsd")] -ioctl_write_ptr!(set_if_mtu, b'i', 127, IfMtu); +const SIOCSIFMTU: c_ulong = iow::(b'i', 127); -// SIOCSIFADDR -ioctl_write_ptr!(set_addr_if, b'i', 12, IfReq); +const SIOCSIFADDR: c_ulong = iow::(b'i', 12); -// SIOCAIFADDR #[cfg(target_os = "freebsd")] -ioctl_write_ptr!(add_addr_if, b'i', 43, InAliasReq); +const SIOCAIFADDR: c_ulong = iow::(b'i', 43); #[cfg(any(target_os = "macos", target_os = "netbsd"))] -ioctl_write_ptr!(add_addr_if, b'i', 26, InAliasReq); +const SIOCAIFADDR: c_ulong = iow::(b'i', 26); -// SIOCDIFADDR -ioctl_write_ptr!(del_addr_if, b'i', 25, IfReq); +const SIOCDIFADDR: c_ulong = iow::(b'i', 25); +const SIOCSIFADDR_IN6: c_ulong = iow::(b'i', 12); -// SIOCSIFADDR_IN6 -ioctl_write_ptr!(set_addr_if_in6, b'i', 12, IfReq6); - -// SIOCAIFADDR_IN6 #[cfg(target_os = "freebsd")] -ioctl_write_ptr!(add_addr_if_in6, b'i', 27, In6AliasReq); +const SIOCAIFADDR_IN6: c_ulong = iow::(b'i', 27); #[cfg(target_os = "macos")] -ioctl_write_ptr!(add_addr_if_in6, b'i', 26, In6AliasReq); +const SIOCAIFADDR_IN6: c_ulong = iow::(b'i', 26); #[cfg(target_os = "netbsd")] -ioctl_write_ptr!(add_addr_if_in6, b'i', 107, In6AliasReq); - -// SIOCDIFADDR_IN6 -ioctl_write_ptr!(del_addr_if_in6, b'i', 25, IfReq6); - -// SIOCSIFFLAGS -ioctl_write_ptr!(set_if_flags, b'i', 16, IfReqFlags); +const SIOCAIFADDR_IN6: c_ulong = iow::(b'i', 107); -// SIOCGIFFLAGS -ioctl_readwrite!(get_if_flags, b'i', 17, IfReqFlags); +const SIOCDIFADDR_IN6: c_ulong = iow::(b'i', 25); +const SIOCSIFFLAGS: c_ulong = iow::(b'i', 16); +const SIOCGIFFLAGS: c_ulong = iowr::(b'i', 17); type IfName = [u8; IF_NAMESIZE]; fn make_ifr_name(if_name: &str) -> IfName { - let mut ifr_name = [0u8; IF_NAMESIZE]; + let mut ifr_name = [0; IF_NAMESIZE]; let len = if_name.len().min(IF_NAMESIZE - 1); ifr_name[..len].copy_from_slice(&if_name.as_bytes()[..len]); ifr_name @@ -99,45 +92,43 @@ impl IfReq { } pub(super) fn create(&mut self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?; - - unsafe { - create_clone_if(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_UNIX)?; + #[cfg(target_os = "netbsd")] + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCIFCREATE, &*self) }; + #[cfg(any(target_os = "freebsd", target_os = "macos"))] + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCIFCREATE2, &*self) }; + c_int_to_error(result)?; Ok(()) } pub(super) fn destroy(&self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?; - - unsafe { - destroy_clone_if(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_UNIX)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCIFDESTROY, self) }; + c_int_to_error(result)?; Ok(()) } pub(super) fn set_address(&self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Inet).map_err(IoError::WriteIo)?; - unsafe { - set_addr_if(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_INET)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCSIFADDR, self) }; + c_int_to_error(result)?; Ok(()) } pub(super) fn delete_address(&self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Inet).map_err(IoError::WriteIo)?; - unsafe { - del_addr_if(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_INET)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCDIFADDR, self) }; + c_int_to_error(result)?; Ok(()) } } /// Represent `struct ifreq` as defined in `net/if.h` - ifr_mtu variant. +#[derive(Debug)] #[repr(C)] pub struct IfMtu { ifr_name: IfName, @@ -156,22 +147,18 @@ impl IfMtu { } pub(super) fn get_mtu(&mut self) -> Result { - let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?; - - unsafe { - get_if_mtu(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_UNIX)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCGIFMTU, &*self) }; + c_int_to_error(result)?; Ok(self.ifru_mtu) } pub(super) fn set_mtu(&mut self, mtu: u32) -> Result<(), IoError> { self.ifru_mtu = mtu; - let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?; - - unsafe { - set_if_mtu(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_UNIX)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCSIFMTU, &*self) }; + c_int_to_error(result)?; Ok(()) } @@ -196,20 +183,17 @@ impl IfReq6 { } pub(super) fn set_address(&self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Inet6).map_err(IoError::WriteIo)?; - - unsafe { - set_addr_if_in6(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_INET6)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCSIFADDR_IN6, self) }; + c_int_to_error(result)?; Ok(()) } pub(super) fn delete_address(&self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Inet6).map_err(IoError::WriteIo)?; - unsafe { - del_addr_if_in6(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_INET6)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCDIFADDR_IN6, self) }; + c_int_to_error(result)?; Ok(()) } @@ -245,11 +229,9 @@ impl InAliasReq { } pub(super) fn add_address(&self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Inet).map_err(IoError::WriteIo)?; - - unsafe { - add_addr_if(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_INET)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCAIFADDR, self) }; + c_int_to_error(result)?; Ok(()) } @@ -296,11 +278,9 @@ impl In6AliasReq { } pub(super) fn add_address(&self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Inet6).map_err(IoError::WriteIo)?; - - unsafe { - add_addr_if_in6(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let socket = create_socket(AF_INET6)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCAIFADDR_IN6, self) }; + c_int_to_error(result)?; Ok(()) } @@ -325,18 +305,15 @@ impl IfReqFlags { } pub(super) fn up(&mut self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?; + let socket = create_socket(AF_UNIX)?; // Get current interface flags. - unsafe { - get_if_flags(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let _result = unsafe { ioctl(socket.as_raw_fd(), SIOCGIFFLAGS, &*self) }; // Set interface up flag. self.ifr_flags |= IFF_UP as u64; - unsafe { - set_if_flags(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?; - } + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCSIFFLAGS, &*self) }; + c_int_to_error(result)?; Ok(()) } diff --git a/src/bsd/ioctl.rs b/src/bsd/ioctl.rs new file mode 100644 index 0000000..2efbcbb --- /dev/null +++ b/src/bsd/ioctl.rs @@ -0,0 +1,46 @@ +//! Compute values for ioctl. +//! +//! Same as libc::{_IOR, _IOW, IOWR}, but these are not available for all platforms. +//! +//! | dir | size | type | nr | +//! |--------|---------|--------|--------| +//! | 31–30 | 29–16 | 15–8 | 7–0 | +//! | 2 bits | 14 bits | 8 bits | 8 bits | + +use std::mem; + +use libc::c_ulong; + +/// Direction (out = read, in = write). +const IOC_VOID: u32 = 0x2000_0000; +const IOC_OUT: u32 = 0x4000_0000; +const IOC_IN: u32 = 0x8000_0000; + +/// Equivalent to the C _IOC() macro. +const fn ioc(dir: u32, ty: u8, nr: u32, size: u32) -> c_ulong { + (dir | (size << 16) | ((ty as u32) << 8) | nr) as c_ulong +} + +/// IO — no data transfer +#[must_use] +pub const fn io(ty: u8, nr: u32) -> c_ulong { + ioc(IOC_VOID, ty, nr, 0) +} + +/// IOR — kernel to userspace (read) +#[must_use] +pub const fn ior(ty: u8, nr: u32) -> c_ulong { + ioc(IOC_IN, ty, nr, mem::size_of::() as u32) +} + +/// IOW — userspace to kernel (write) +#[must_use] +pub const fn iow(ty: u8, nr: u32) -> c_ulong { + ioc(IOC_IN, ty, nr, mem::size_of::() as u32) +} + +/// IOWR — bidirectional communication (write-read) +#[must_use] +pub const fn iowr(ty: u8, nr: u32) -> c_ulong { + ioc(IOC_OUT | IOC_IN, ty, nr, mem::size_of::() as u32) +} diff --git a/src/bsd/mod.rs b/src/bsd/mod.rs index 2dfee1d..8c6e57f 100644 --- a/src/bsd/mod.rs +++ b/src/bsd/mod.rs @@ -1,24 +1,25 @@ mod ifconfig; +pub mod ioctl; mod nvlist; mod route; mod sockaddr; +#[cfg(test)] +mod tests; mod timespec; mod wgio; use std::{ collections::HashMap, ffi::{CStr, CString}, + io, mem::{MaybeUninit, size_of}, net::IpAddr, - os::fd::OwnedFd, + os::fd::{FromRawFd, OwnedFd}, ptr::from_ref, slice::from_raw_parts, }; -use nix::{ - errno::Errno, - sys::socket::{AddressFamily, SockFlag, SockType, socket}, -}; +use libc::{IPPROTO_IP, SOCK_DGRAM, c_int, socket}; use route::{DestAddrMask, GatewayLink}; use sockaddr::{SockAddrDl, SockAddrIn, SockAddrIn6, SocketFromRaw}; use thiserror::Error; @@ -72,19 +73,29 @@ unsafe fn cast_bytes(p: &T) -> &[u8] { unsafe { from_raw_parts(from_ref::(p).cast::(), size_of::()) } } +/// Convert result of -1 to `IoError`, by taking `errno`. +pub fn c_int_to_error(result: c_int) -> Result<(), io::Error> { + if result == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } +} + /// Create socket for ioctl communication. -fn create_socket(address_family: AddressFamily) -> Result { - socket(address_family, SockType::Datagram, SockFlag::empty(), None) +fn create_socket(address_family: c_int) -> Result { + match unsafe { socket(address_family, SOCK_DGRAM, IPPROTO_IP) } { + -1 => Err(io::Error::last_os_error()), + fd => unsafe { Ok(OwnedFd::from_raw_fd(fd)) }, + } } #[derive(Debug, Error)] pub enum IoError { #[error("Memory allocation error")] MemAlloc, - #[error("Read error {0}")] - ReadIo(Errno), - #[error("Write error {0}")] - WriteIo(Errno), + #[error("I/O error {0}")] + Io(io::Error), #[error("Network interface does not exist")] NetworkInterface, #[error("Not enough bytes to unpack")] @@ -93,6 +104,12 @@ pub enum IoError { KernelModule, } +impl From for IoError { + fn from(error: io::Error) -> Self { + Self::Io(error) + } +} + impl From for WireguardInterfaceError { fn from(error: IoError) -> Self { WireguardInterfaceError::BsdError(error.to_string()) diff --git a/src/bsd/route.rs b/src/bsd/route.rs index f74dce1..d87fb19 100644 --- a/src/bsd/route.rs +++ b/src/bsd/route.rs @@ -1,18 +1,15 @@ use std::{ ffi::{CStr, CString}, + io, mem::{MaybeUninit, size_of}, net::IpAddr, - os::fd::{AsFd, AsRawFd}, + os::fd::{AsRawFd, FromRawFd, OwnedFd}, }; -use nix::{ - errno::Errno, - sys::socket::{AddressFamily, Shutdown, SockFlag, SockType, shutdown, socket}, - unistd::{read, write}, -}; +use libc::{ESRCH, PF_ROUTE, SHUT_RD, SOCK_RAW, read, shutdown, socket, write}; use super::{ - IoError, cast_bytes, cast_ref, + IoError, c_int_to_error, cast_bytes, cast_ref, sockaddr::{SockAddrDl, SocketFromRaw, unpack_sockaddr}, }; @@ -351,30 +348,54 @@ impl RtMessage { Self { header, payload } } + /// Create socket for ioctl communication. + fn socket() -> Result { + match unsafe { socket(PF_ROUTE, SOCK_RAW, 0) } { + -1 => Err(io::Error::last_os_error()), + fd => unsafe { Ok(OwnedFd::from_raw_fd(fd)) }, + } + } + pub(super) fn execute(&self) -> Result<(), IoError> { - let socket = socket(AddressFamily::Route, SockType::Raw, SockFlag::empty(), None) - .map_err(IoError::WriteIo)?; + let socket = Self::socket()?; // Don't want to read back our messages. - shutdown(socket.as_raw_fd(), Shutdown::Read).map_err(IoError::WriteIo)?; + let result = unsafe { shutdown(socket.as_raw_fd(), SHUT_RD) }; + c_int_to_error(result)?; let buf = unsafe { cast_bytes(self) }; - match write(socket.as_fd(), buf) { - Ok(_) | Err(Errno::ESRCH) => Ok(()), // not in table - Err(err) => Err(IoError::WriteIo(err)), + let result = unsafe { write(socket.as_raw_fd(), buf.as_ptr().cast(), buf.len()) }; + if result == -1 { + let err = io::Error::last_os_error(); + if let Some(raw_os_err) = err.raw_os_error() + && raw_os_err == ESRCH + { + // not in table + } else { + return Err(err)?; + } } + + Ok(()) } pub(super) fn get_gateway(&self) -> Result, IoError> { - let socket = socket(AddressFamily::Route, SockType::Raw, SockFlag::empty(), None) - .map_err(IoError::WriteIo)?; + let socket = Self::socket()?; let buf = unsafe { cast_bytes(self) }; - match write(socket.as_fd(), buf) { - Ok(_) => (), - Err(Errno::ESRCH) => return Ok(None), // not in table - Err(err) => return Err(IoError::WriteIo(err)), + let result = unsafe { write(socket.as_raw_fd(), buf.as_ptr().cast(), buf.len()) }; + if result == -1 { + let err = io::Error::last_os_error(); + if let Some(raw_os_err) = err.raw_os_error() + && raw_os_err == ESRCH + { + return Ok(None); // not in table + } + return Err(err)?; } let mut buf = [0u8; 256]; // FIXME: fixed buffer size - let len = read(socket.as_fd(), &mut buf).map_err(IoError::ReadIo)?; + let len = match unsafe { read(socket.as_raw_fd(), buf.as_mut_ptr().cast(), buf.len()) } { + -1 => return Err(io::Error::last_os_error())?, + result => result.cast_unsigned(), + }; if len < size_of::() { return Err(IoError::Unpack); } diff --git a/src/bsd/tests.rs b/src/bsd/tests.rs new file mode 100644 index 0000000..87b3e9c --- /dev/null +++ b/src/bsd/tests.rs @@ -0,0 +1,20 @@ +use std::net::Ipv4Addr; + +use super::*; + +#[ignore = "requires root access"] +#[test] +fn test_assign_address() { + let ifname = "lo0"; + let address = IpAddrMask::new(IpAddr::V4(Ipv4Addr::new(127, 1, 1, 1)), 8); + assign_address(ifname, &address).unwrap(); + // TODO: get_address() + remove_address(ifname, &address).unwrap(); +} + +#[test] +fn test_get_mtu() { + let ifname = "lo0"; + let mtu = get_mtu(ifname).unwrap(); + assert_eq!(mtu, 16384); +} diff --git a/src/bsd/wgio.rs b/src/bsd/wgio.rs index 8057506..22da516 100644 --- a/src/bsd/wgio.rs +++ b/src/bsd/wgio.rs @@ -5,14 +5,14 @@ use std::{ slice::from_raw_parts, }; -use libc::IF_NAMESIZE; -use nix::{ioctl_readwrite, sys::socket::AddressFamily}; +use libc::{AF_UNIX, IF_NAMESIZE, c_ulong, ioctl}; -use super::{IoError, create_socket}; +use super::{IoError, c_int_to_error, create_socket, ioctl::iowr}; // FIXME: `WgReadIo` and `WgWriteIo` have to be declared public. -ioctl_readwrite!(write_wireguard_data, b'i', 210, WgWriteIo); -ioctl_readwrite!(read_wireguard_data, b'i', 211, WgReadIo); +// From `dev/wg/if_wg.h`. +const SIOCSWG: c_ulong = iowr::(b'i', 210); +const SIOCGWG: c_ulong = iowr::(b'i', 211); /// Represent `struct wg_data_io` defined in /// https://github.com/freebsd/freebsd-src/blob/main/sys/dev/wg/if_wg.h @@ -59,21 +59,17 @@ impl WgReadIo { } pub(super) fn read_data(&mut self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Unix).map_err(IoError::ReadIo)?; - unsafe { - // First do ioctl with empty `wg_data` to obtain buffer size. - if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) { - error!("WgReadIo first read error {err}"); - return Err(IoError::ReadIo(err)); - } - // Allocate buffer. - self.alloc_data()?; - // Second call to ioctl with allocated buffer. - if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) { - error!("WgReadIo second read error {err}"); - return Err(IoError::ReadIo(err)); - } - } + let socket = create_socket(AF_UNIX)?; + + // First do ioctl with empty `wg_data` to obtain buffer size. + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCGWG, &*self) }; + c_int_to_error(result)?; + + // Allocate buffer. + self.alloc_data()?; + // Second call to ioctl with allocated buffer. + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCGWG, &*self) }; + c_int_to_error(result)?; Ok(()) } @@ -116,13 +112,9 @@ impl WgWriteIo { } pub(super) fn write_data(&mut self) -> Result<(), IoError> { - let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?; - unsafe { - if let Err(err) = write_wireguard_data(socket.as_raw_fd(), self) { - error!("WgWriteIo write error {err}"); - return Err(IoError::WriteIo(err)); - } - } + let socket = create_socket(AF_UNIX)?; + let result = unsafe { ioctl(socket.as_raw_fd(), SIOCSWG, &*self) }; + c_int_to_error(result)?; Ok(()) } diff --git a/src/host.rs b/src/host.rs index bea7891..d7767e2 100644 --- a/src/host.rs +++ b/src/host.rs @@ -13,15 +13,12 @@ use std::{ }; #[cfg(target_os = "linux")] -use netlink_packet_wireguard::WireguardAttribute; +use netlink_packet_wireguard::{WireguardAttribute, WireguardDeviceFlags}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use super::{key::Key, peer::Peer}; -#[cfg(target_os = "linux")] -const WGDEVICE_F_REPLACE_PEERS: u32 = 1; - /// WireGuard host representation. #[derive(Clone, Default)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] @@ -239,7 +236,9 @@ impl Host { if let Some(fwmark) = &self.fwmark { nlas.push(WireguardAttribute::Fwmark(*fwmark)); } - nlas.push(WireguardAttribute::Flags(WGDEVICE_F_REPLACE_PEERS)); + nlas.push(WireguardAttribute::Flags( + WireguardDeviceFlags::ReplacePeers, + )); // IMPORTANT: To avoid buffer overflow, do not add peers here. // let peers = self.peers.values().map(Peer::as_nlas_peer).collect(); diff --git a/src/netlink.rs b/src/netlink.rs index fdec1a6..49fea0a 100644 --- a/src/netlink.rs +++ b/src/netlink.rs @@ -21,6 +21,7 @@ use netlink_packet_route::{ }; use netlink_packet_wireguard::{ WireguardAttribute, WireguardCmd, WireguardMessage, WireguardPeer, WireguardPeerAttribute, + WireguardPeerFlags, }; use netlink_sys::{ Socket, SocketAddr, @@ -31,7 +32,6 @@ use thiserror::Error; use crate::{IpVersion, Key, WireguardInterfaceError, host::Host, net::IpAddrMask, peer::Peer}; const SOCKET_BUFFER_LENGTH: usize = 12288; -const WGPEER_F_REMOVE_ME: u32 = 1; #[derive(Debug, Error)] pub(crate) enum NetlinkError { @@ -88,7 +88,7 @@ impl Key { WireguardAttribute::IfName(ifname.into()), WireguardAttribute::Peers(vec![WireguardPeer(vec![ WireguardPeerAttribute::PublicKey(self.as_array()), - WireguardPeerAttribute::Flags(WGPEER_F_REMOVE_ME), + WireguardPeerAttribute::Flags(WireguardPeerFlags::RemoveMe), ])]), ] } diff --git a/src/peer.rs b/src/peer.rs index 2dd528c..6c83f07 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -8,15 +8,13 @@ use std::{fmt, net::SocketAddr, time::SystemTime}; #[cfg(target_os = "linux")] use netlink_packet_wireguard::{ WireguardAllowedIpAttr, WireguardAttribute, WireguardPeer, WireguardPeerAttribute, + WireguardPeerFlags, }; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::{error::WireguardInterfaceError, key::Key, net::IpAddrMask, utils::resolve}; -#[cfg(target_os = "linux")] -const WGPEER_F_REPLACE_ALLOWEDIPS: u32 = 2; - /// WireGuard peer representation. #[derive(Clone, Default, PartialEq)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] @@ -127,7 +125,7 @@ impl Peer { match nla { WireguardPeerAttribute::PublicKey(value) => peer.public_key = Key::new(*value), WireguardPeerAttribute::PresharedKey(value) => { - peer.preshared_key = Some(Key::new(*value)) + peer.preshared_key = Some(Key::new(*value)); } WireguardPeerAttribute::Endpoint(value) => peer.endpoint = Some(*value), WireguardPeerAttribute::PersistentKeepalive(value) => { @@ -136,7 +134,7 @@ impl Peer { WireguardPeerAttribute::LastHandshake(value) => { let duration = Duration::from_secs(value.seconds.cast_unsigned()) .saturating_add(Duration::from_nanos(value.nano_seconds.cast_unsigned())); - peer.last_handshake = Some(SystemTime::UNIX_EPOCH + duration) + peer.last_handshake = Some(SystemTime::UNIX_EPOCH + duration); } WireguardPeerAttribute::RxBytes(value) => peer.rx_bytes = *value, WireguardPeerAttribute::TxBytes(value) => peer.tx_bytes = *value, @@ -189,7 +187,9 @@ impl Peer { )); } - attrs.push(WireguardPeerAttribute::Flags(WGPEER_F_REPLACE_ALLOWEDIPS)); + attrs.push(WireguardPeerAttribute::Flags( + WireguardPeerFlags::ReplaceAllowedIps, + )); let allowed_ips = self .allowed_ips .iter() diff --git a/src/utils.rs b/src/utils.rs index 25dd64c..63c7821 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -382,8 +382,6 @@ pub(crate) fn add_peer_routing( peers: &[Peer], ifname: &str, ) -> Result<(), WireguardInterfaceError> { - use nix::errno::Errno; - use crate::bsd::{IoError, delete_gateway}; let gateway_v4 = get_gateway(IpVersion::IPv4); @@ -423,8 +421,10 @@ pub(crate) fn add_peer_routing( } match add_linked_route(&default1, ifname) { Ok(()) => debug!("Route to {default1} has been added for interface {ifname}"), - Err(err) => match err { - IoError::WriteIo(Errno::ENETUNREACH) => { + Err(IoError::Io(err)) => { + if let Some(raw_os_err) = err.raw_os_error() + && raw_os_err == libc::ENETUNREACH + { warn!( "Failed to add default route {default1} for interface {ifname}: \ Network is unreachable. This may happen if interface's IP address \ @@ -433,18 +433,22 @@ pub(crate) fn add_peer_routing( ignored. Otherwise, there may be some other issues with network \ configuration." ); - } - _ => { + } else { error!( "Failed to add route to {default1} for interface {ifname}: {err}" ); } - }, + } + Err(err) => { + error!("Failed to add route to {default1} for interface {ifname}: {err}"); + } } match add_linked_route(&default2, ifname) { Ok(()) => debug!("Route to {default2} has been added for interface {ifname}"), - Err(err) => match err { - IoError::WriteIo(Errno::ENETUNREACH) => { + Err(IoError::Io(err)) => { + if let Some(raw_os_err) = err.raw_os_error() + && raw_os_err == libc::ENETUNREACH + { warn!( "Failed to add default route {default2} for interface {ifname}: \ Network is unreachable. This may happen if interface's IP address \ @@ -453,13 +457,15 @@ pub(crate) fn add_peer_routing( ignored. Otherwise, there may be some other issues with network \ configuration." ); - } - _ => { + } else { error!( "Failed to add route to {default2} for interface {ifname}: {err}" ); } - }, + } + Err(err) => { + error!("Failed to add route to {default2} for interface {ifname}: {err}"); + } } } else { // Equivalent to `route -n add -inet[6] -interface `. @@ -614,7 +620,7 @@ pub(crate) fn get_command_path(command: &str) -> Result, Wiregua let paths = env::var_os("PATH").ok_or_else(|| { WireguardInterfaceError::MissingDependency("Environment variable `PATH` not found".into()) })?; - debug!("PATH variable: {paths:?}"); + debug!("PATH variable: {}", paths.display()); Ok(env::split_paths(&paths).find_map(|dir| { let full_path = dir.join(command);