diff --git a/litebox_common_linux/src/lib.rs b/litebox_common_linux/src/lib.rs index abd6eee80..9504adadc 100644 --- a/litebox_common_linux/src/lib.rs +++ b/litebox_common_linux/src/lib.rs @@ -1791,7 +1791,7 @@ pub struct UserMsgHdr { /// number of bytes of ancillary data pub msg_controllen: usize, /// flags on received message - pub msg_flags: SendFlags, + pub msg_flags: ReceiveFlags, /// Explicit trailing padding to match the 4-byte gap after `msg_flags` in /// Linux's naturally-aligned `struct user_msghdr` on 64-bit (total size 56). #[cfg(target_pointer_width = "64")] @@ -2035,6 +2035,11 @@ pub enum SyscallRequest { addr: Option>, addrlen: Platform::RawMutPointer, }, + Recvmsg { + sockfd: i32, + msg: Platform::RawMutPointer>, + flags: ReceiveFlags, + }, Bind { sockfd: i32, sockaddr: Platform::RawConstPointer, @@ -2510,6 +2515,7 @@ impl SyscallRequest { Sysno::sendto => sys_req!(Sendto { sockfd, buf:*, len, flags, addr:*, addrlen }), Sysno::sendmsg => sys_req!(Sendmsg { sockfd, msg:*, flags }), Sysno::recvfrom => sys_req!(Recvfrom { sockfd, buf:*, len, flags, addr:*, addrlen:*, }), + Sysno::recvmsg => sys_req!(Recvmsg { sockfd, msg:*, flags }), Sysno::bind => sys_req!(Bind { sockfd, sockaddr:*, addrlen }), Sysno::listen => sys_req!(Listen { sockfd, backlog }), Sysno::setsockopt => sys_req!(Setsockopt { diff --git a/litebox_shim_linux/src/lib.rs b/litebox_shim_linux/src/lib.rs index 2834f7b72..f89b538ec 100644 --- a/litebox_shim_linux/src/lib.rs +++ b/litebox_shim_linux/src/lib.rs @@ -712,6 +712,7 @@ impl Task { addr, addrlen, } => self.sys_recvfrom(sockfd, buf, len, flags, addr, addrlen), + SyscallRequest::Recvmsg { sockfd, msg, flags } => self.sys_recvmsg(sockfd, msg, flags), SyscallRequest::Bind { sockfd, sockaddr, diff --git a/litebox_shim_linux/src/syscalls/net.rs b/litebox_shim_linux/src/syscalls/net.rs index 44e24a927..b6b6125e1 100644 --- a/litebox_shim_linux/src/syscalls/net.rs +++ b/litebox_shim_linux/src/syscalls/net.rs @@ -1475,15 +1475,29 @@ impl Task { None }, )?; - buf.copy_from_slice(0, &recv_buf[..size.min(recv_buf.len())]) + let capped_size = size.min(recv_buf.len()); + buf.copy_from_slice(0, &recv_buf[..capped_size]) .ok_or(Errno::EFAULT)?; if let Some(src_addr) = source_addr && let Some(sock_ptr) = addr { write_sockaddr_to_user(src_addr, sock_ptr, addrlen)?; } - Ok(size) + + if flags.contains(ReceiveFlags::TRUNC) { + // the actual message size + Ok(size) + } else { + // the number of bytes copied + Ok(capped_size) + } } + /// Receive data from a socket. + /// + /// `source_addr` can be provided to receive the source address if available. + /// + /// On success, returns the number of bytes received. Note that for datagram sockets, + /// this may be larger than the provided buffer length as the excessive data will be truncated. fn do_recvfrom( &self, sockfd: u32, @@ -1527,17 +1541,128 @@ impl Task { )? }; - if !flags.contains(ReceiveFlags::TRUNC) { - let len = buf.len(); - assert!(size <= len, "{size} should be smaller than {len}"); - } - if let (Some(source_addr), Some(addr)) = (source_addr, addr) { *source_addr = Some(addr); } Ok(size) } + /// Handle syscall `recvmsg` + pub(crate) fn sys_recvmsg( + &self, + fd: i32, + msg_ptr: MutPtr>, + flags: ReceiveFlags, + ) -> Result { + const MAX_LEN: usize = 65536; + + let Ok(sockfd) = u32::try_from(fd) else { + return Err(Errno::EBADF); + }; + let msg = msg_ptr.read_at_offset(0).ok_or(Errno::EFAULT)?; + + // Copy fields out of the packed struct to avoid unaligned references. + let msg_name = msg.msg_name; + let msg_iov = msg.msg_iov; + let msg_iovlen = msg.msg_iovlen; + let msg_controllen = msg.msg_controllen; + + if msg_controllen != 0 { + log_unsupported!("ancillary data is not supported"); + } + if msg_iovlen == 0 || msg_iovlen > 1024 { + return Err(Errno::EINVAL); + } + + let iovs = msg_iov.to_owned_slice(msg_iovlen).ok_or(Errno::EFAULT)?; + + // Compute total buffer capacity across all non-empty iovecs, capped at MAX_LEN. + let total_iov_capacity: usize = iovs + .iter() + .map(|iov| iov.iov_len) + .fold(0usize, usize::saturating_add) + .min(MAX_LEN); + + if total_iov_capacity == 0 { + return Ok(0); + } + + // Perform a single recv into a contiguous buffer. + let want_source = msg_name.as_usize() != 0; + let mut source_addr = None; + let mut ret_flags = ReceiveFlags::empty(); + + // Heap-allocate the recv buffer to avoid stack overflow for large iovecs. + let mut buffer = alloc::vec![0u8; total_iov_capacity]; + let recv_buf = &mut buffer[..]; + let size = self.do_recvfrom( + sockfd, + recv_buf, + flags, + if want_source { + Some(&mut source_addr) + } else { + None + }, + )?; + + // Set MSG_TRUNC if the received message was larger than the total buffer. + if size > total_iov_capacity { + ret_flags |= ReceiveFlags::TRUNC; + } + + // Scatter the received data across iovecs sequentially. + let data_to_copy = size.min(total_iov_capacity); + let mut offset = 0usize; + for iov in &iovs { + if offset >= data_to_copy { + break; + } + if iov.iov_len == 0 { + continue; + } + let chunk = (data_to_copy - offset).min(iov.iov_len); + iov.iov_base + .copy_from_slice(0, &recv_buf[offset..offset + chunk]) + .ok_or(Errno::EFAULT)?; + offset += chunk; + } + + let total_received = if flags.contains(ReceiveFlags::TRUNC) { + // the actual message size + size + } else { + // the number of bytes copied + size.min(total_iov_capacity) + }; + + // Write back source address if requested. + if want_source { + let addrlen_ptr = MutPtr::::from_usize( + msg_ptr.as_usize() + + core::mem::offset_of!( + litebox_common_linux::UserMsgHdr, + msg_namelen + ), + ); + if let Some(src_addr) = source_addr { + let addr_ptr = MutPtr::::from_usize(msg_name.as_usize()); + write_sockaddr_to_user(src_addr, addr_ptr, addrlen_ptr)?; + } else { + // No source address (e.g. connected stream socket) — zero out msg_namelen. + let _ = addrlen_ptr.write_at_offset(0, 0u32); + } + } + + // Write back msg_flags with any status flags (e.g. MSG_TRUNC). + let flags_offset = + core::mem::offset_of!(litebox_common_linux::UserMsgHdr, msg_flags); + let flags_ptr = MutPtr::::from_usize(msg_ptr.as_usize() + flags_offset); + let _ = flags_ptr.write_at_offset(0, ret_flags); + + Ok(total_received) + } + pub(crate) fn sys_setsockopt( &self, sockfd: i32, @@ -1842,6 +1967,13 @@ impl Task { flags: 2, }) } + SocketcallType::Recvmsg => { + parse_socketcall_args!(3 => sys_recvmsg { + sockfd: 0, + msg: [ 1 ], + flags: 2, + }) + } _ => { log_unsupported!("socketcall type {socketcall_type:?} is not supported"); Err(Errno::EINVAL) @@ -1862,6 +1994,7 @@ mod tests { AddressFamily, ReceiveFlags, SendFlags, SockFlags, SockType, SocketOption, SocketOptionName, TcpOption, errno::Errno, }; + use zerocopy::FromZeros as _; use super::SocketAddress; use crate::{ConstPtr, MutPtr, syscalls::tests::init_platform}; @@ -1982,7 +2115,7 @@ mod tests { ]) .stdout(std::process::Stdio::piped()) .output(), - "recvfrom" => std::process::Command::new("sh") + "recvfrom" | "recvmsg" => std::process::Command::new("sh") .args([ "-c", &alloc::format!( @@ -2052,7 +2185,6 @@ mod tests { }, ]; let hdr = { - use zerocopy::FromZeros as _; let mut h = litebox_common_linux::UserMsgHdr::::new_zeroed(); h.msg_iov = ConstPtr::from_usize(iovec.as_ptr() as usize); h.msg_iovlen = iovec.len(); @@ -2070,7 +2202,7 @@ mod tests { let stdout = alloc::string::String::from_utf8_lossy(&output.stdout); assert_eq!(stdout, alloc::format!("{buf1}{buf2}")); } - "recvfrom" => { + "recvfrom" | "recvmsg" => { if is_nonblocking { epoll_add(task, epfd, client_fd, litebox::event::Events::IN); let mut events = [litebox_common_linux::EpollEvent { events: 0, data: 0 }; 2]; @@ -2082,18 +2214,30 @@ mod tests { } } let mut recv_buf = [0u8; 48]; - let n = task - .do_recvfrom( - client_fd, - &mut recv_buf, - if test_trunc { - ReceiveFlags::TRUNC - } else { - ReceiveFlags::empty() - }, - None, - ) - .expect("Failed to receive data"); + let flags = if test_trunc { + ReceiveFlags::TRUNC + } else { + ReceiveFlags::empty() + }; + let n = match option { + "recvfrom" => task + .do_recvfrom(client_fd, &mut recv_buf, flags, None) + .expect("Failed to receive data"), + "recvmsg" => { + let iovec = [litebox_common_linux::IoVec { + iov_base: MutPtr::from_usize(recv_buf.as_mut_ptr().expose_provenance()), + iov_len: recv_buf.len(), + }]; + let mut msg_hdr = + litebox_common_linux::UserMsgHdr::::new_zeroed(); + msg_hdr.msg_iov = ConstPtr::from_usize(iovec.as_ptr() as usize); + msg_hdr.msg_iovlen = iovec.len(); + let msg_ptr = MutPtr::from_usize(&raw mut msg_hdr as usize); + task.sys_recvmsg(i32::try_from(client_fd).unwrap(), msg_ptr, flags) + .expect("failed to recvmsg") + } + _ => unreachable!(), + }; if test_trunc { assert!(recv_buf.iter().all(|&b| b == 0)); // buf remains unchanged } else { @@ -2109,14 +2253,24 @@ mod tests { close_socket(task, server); } - fn test_tcp_socket_with_external_client( - port: u16, - is_nonblocking: bool, - test_trunc: bool, - option: &'static str, - ) { + fn test_tcp_socket_with_external_client(port: u16, is_nonblocking: bool, test_trunc: bool) { let task = init_platform(Some(TUN_DEVICE_NAME)); - test_tcp_socket_as_server(&task, TUN_IP_ADDR, port, is_nonblocking, test_trunc, option); + test_tcp_socket_as_server( + &task, + TUN_IP_ADDR, + port, + is_nonblocking, + test_trunc, + "recvfrom", + ); + test_tcp_socket_as_server( + &task, + TUN_IP_ADDR, + port, + is_nonblocking, + test_trunc, + "recvmsg", + ); } fn test_tcp_socket_send(is_nonblocking: bool, test_trunc: bool) { @@ -2151,17 +2305,17 @@ mod tests { #[test] fn test_tun_blocking_recvfrom_tcp_socket() { - test_tcp_socket_with_external_client(SERVER_PORT, false, false, "recvfrom"); + test_tcp_socket_with_external_client(SERVER_PORT, false, false); } #[test] fn test_tun_nonblocking_recvfrom_tcp_socket() { - test_tcp_socket_with_external_client(SERVER_PORT, true, false, "recvfrom"); + test_tcp_socket_with_external_client(SERVER_PORT, true, false); } #[test] fn test_tun_blocking_recvfrom_tcp_socket_with_truncation() { - test_tcp_socket_with_external_client(SERVER_PORT, false, true, "recvfrom"); + test_tcp_socket_with_external_client(SERVER_PORT, false, true); } #[test] diff --git a/litebox_shim_linux/src/syscalls/unix.rs b/litebox_shim_linux/src/syscalls/unix.rs index 6aec592d8..a75cf3d28 100644 --- a/litebox_shim_linux/src/syscalls/unix.rs +++ b/litebox_shim_linux/src/syscalls/unix.rs @@ -864,55 +864,29 @@ impl WriteEnd { } } impl ReadEnd { - /// Attempts to read datagram messages without blocking. + /// Attempts to read a single datagram message without blocking. /// - /// Reads multiple messages from the same source address until the buffer - /// is full or a message from a different source is encountered. + /// Reads exactly one message, preserving message boundaries. If the buffer + /// is smaller than the message, the excess data is discarded (truncated). + /// Returns the original message size (which may exceed `buf.len()`). fn try_read( &self, - mut buf: &mut [u8], - source_addr: Option<&mut Option>, + buf: &mut [u8], + mut source_addr: Option<&mut Option>, ) -> Result> { - let mut src = None; - let mut total_read = 0; - let mut stop = false; - while !buf.is_empty() { - let n = match self.peek_and_consume_one(|msg| { - if src.as_ref().is_some_and(|addr| *addr != msg.source) { - stop = true; - return Ok((false, 0)); - } - if src.is_none() { - src.replace(msg.source.clone()); - } - if buf.len() >= msg.data.len() { - buf[..msg.data.len()].copy_from_slice(&msg.data); - Ok((true, msg.data.len())) - } else { - buf.copy_from_slice(&msg.data[..buf.len()]); - msg.data = msg.data.split_off(buf.len()); - Ok((false, buf.len())) - } - }) { - Ok(0) if stop => break, - Ok(n) => n, - Err(e) => { - if total_read > 0 { - break; - } - return match e { - Errno::EAGAIN => Err(TryOpError::TryAgain), - other => Err(TryOpError::Other(other)), - }; - } - }; - total_read += n; - buf = &mut buf[n..]; - } - if let (Some(src), Some(source_addr)) = (src, source_addr) { - *source_addr = Some(src); - } - Ok(total_read) + self.peek_and_consume_one(|msg| { + let copy_len = buf.len().min(msg.data.len()); + buf[..copy_len].copy_from_slice(&msg.data[..copy_len]); + if let Some(source_addr) = source_addr.as_deref_mut() { + *source_addr = Some(msg.source.clone()); + } + // Always consume the entire message to preserve boundaries. + Ok((true, msg.data.len())) + }) + .map_err(|e| match e { + Errno::EAGAIN => TryOpError::TryAgain, + other => TryOpError::Other(other), + }) } } @@ -1265,7 +1239,7 @@ impl UnixSocket { flags: ReceiveFlags, source_addr: Option<&mut Option>, ) -> Result { - let supported_flags = ReceiveFlags::DONTWAIT; + let supported_flags = ReceiveFlags::DONTWAIT | ReceiveFlags::TRUNC; if flags.intersects(supported_flags.complement()) { log_unsupported!("Unsupported recvfrom flags: {:?}", flags); return Err(Errno::EINVAL);