Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion litebox_common_linux/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1791,7 +1791,7 @@ pub struct UserMsgHdr<Platform: litebox::platform::RawPointerProvider> {
/// number of bytes of ancillary data
pub msg_controllen: usize,
/// flags on received message
pub msg_flags: SendFlags,
pub msg_flags: ReceiveFlags,
/// Explicit trailing padding to match the 4-byte gap after `msg_flags` in
/// Linux's naturally-aligned `struct user_msghdr` on 64-bit (total size 56).
#[cfg(target_pointer_width = "64")]
Expand Down Expand Up @@ -2035,6 +2035,11 @@ pub enum SyscallRequest<Platform: litebox::platform::RawPointerProvider> {
addr: Option<Platform::RawMutPointer<u8>>,
addrlen: Platform::RawMutPointer<u32>,
},
Recvmsg {
sockfd: i32,
msg: Platform::RawMutPointer<UserMsgHdr<Platform>>,
flags: ReceiveFlags,
},
Bind {
sockfd: i32,
sockaddr: Platform::RawConstPointer<u8>,
Expand Down Expand Up @@ -2510,6 +2515,7 @@ impl<Platform: litebox::platform::RawPointerProvider> SyscallRequest<Platform> {
Sysno::sendto => sys_req!(Sendto { sockfd, buf:*, len, flags, addr:*, addrlen }),
Sysno::sendmsg => sys_req!(Sendmsg { sockfd, msg:*, flags }),
Sysno::recvfrom => sys_req!(Recvfrom { sockfd, buf:*, len, flags, addr:*, addrlen:*, }),
Sysno::recvmsg => sys_req!(Recvmsg { sockfd, msg:*, flags }),
Sysno::bind => sys_req!(Bind { sockfd, sockaddr:*, addrlen }),
Sysno::listen => sys_req!(Listen { sockfd, backlog }),
Sysno::setsockopt => sys_req!(Setsockopt {
Expand Down
1 change: 1 addition & 0 deletions litebox_shim_linux/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ impl<FS: ShimFS> Task<FS> {
addr,
addrlen,
} => self.sys_recvfrom(sockfd, buf, len, flags, addr, addrlen),
SyscallRequest::Recvmsg { sockfd, msg, flags } => self.sys_recvmsg(sockfd, msg, flags),
SyscallRequest::Bind {
sockfd,
sockaddr,
Expand Down
218 changes: 186 additions & 32 deletions litebox_shim_linux/src/syscalls/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1475,15 +1475,29 @@ impl<FS: ShimFS> Task<FS> {
None
},
)?;
buf.copy_from_slice(0, &recv_buf[..size.min(recv_buf.len())])
let capped_size = size.min(recv_buf.len());
buf.copy_from_slice(0, &recv_buf[..capped_size])
.ok_or(Errno::EFAULT)?;
if let Some(src_addr) = source_addr
&& let Some(sock_ptr) = addr
{
write_sockaddr_to_user(src_addr, sock_ptr, addrlen)?;
}
Ok(size)

if flags.contains(ReceiveFlags::TRUNC) {
// the actual message size
Ok(size)
} else {
// the number of bytes copied
Ok(capped_size)
}
}
/// Receive data from a socket.
///
/// `source_addr` can be provided to receive the source address if available.
///
/// On success, returns the number of bytes received. Note that for datagram sockets,
/// this may be larger than the provided buffer length, because any excess data is truncated.
fn do_recvfrom(
&self,
sockfd: u32,
Expand Down Expand Up @@ -1527,17 +1541,128 @@ impl<FS: ShimFS> Task<FS> {
)?
};

if !flags.contains(ReceiveFlags::TRUNC) {
let len = buf.len();
assert!(size <= len, "{size} should be smaller than {len}");
}

if let (Some(source_addr), Some(addr)) = (source_addr, addr) {
*source_addr = Some(addr);
}
Ok(size)
}

/// Handle syscall `recvmsg`
pub(crate) fn sys_recvmsg(
&self,
fd: i32,
msg_ptr: MutPtr<litebox_common_linux::UserMsgHdr<Platform>>,
flags: ReceiveFlags,
) -> Result<usize, Errno> {
const MAX_LEN: usize = 65536;

let Ok(sockfd) = u32::try_from(fd) else {
return Err(Errno::EBADF);
};
let msg = msg_ptr.read_at_offset(0).ok_or(Errno::EFAULT)?;

// Copy fields out of the packed struct to avoid unaligned references.
let msg_name = msg.msg_name;
let msg_iov = msg.msg_iov;
let msg_iovlen = msg.msg_iovlen;
let msg_controllen = msg.msg_controllen;

if msg_controllen != 0 {
log_unsupported!("ancillary data is not supported");
}
if msg_iovlen == 0 || msg_iovlen > 1024 {
return Err(Errno::EINVAL);
}

let iovs = msg_iov.to_owned_slice(msg_iovlen).ok_or(Errno::EFAULT)?;

// Compute total buffer capacity across all non-empty iovecs, capped at MAX_LEN.
let total_iov_capacity: usize = iovs
.iter()
.map(|iov| iov.iov_len)
.fold(0usize, usize::saturating_add)
.min(MAX_LEN);

if total_iov_capacity == 0 {
return Ok(0);
}

// Perform a single recv into a contiguous buffer.
let want_source = msg_name.as_usize() != 0;
let mut source_addr = None;
let mut ret_flags = ReceiveFlags::empty();

// Heap-allocate the recv buffer to avoid stack overflow for large iovecs.
let mut buffer = alloc::vec![0u8; total_iov_capacity];
let recv_buf = &mut buffer[..];
let size = self.do_recvfrom(
sockfd,
recv_buf,
flags,
if want_source {
Some(&mut source_addr)
} else {
None
},
)?;

// Set MSG_TRUNC if the received message was larger than the total buffer.
if size > total_iov_capacity {
ret_flags |= ReceiveFlags::TRUNC;
}

// Scatter the received data across iovecs sequentially.
let data_to_copy = size.min(total_iov_capacity);
let mut offset = 0usize;
for iov in &iovs {
if offset >= data_to_copy {
break;
}
if iov.iov_len == 0 {
continue;
}
let chunk = (data_to_copy - offset).min(iov.iov_len);
iov.iov_base
.copy_from_slice(0, &recv_buf[offset..offset + chunk])
.ok_or(Errno::EFAULT)?;
offset += chunk;
}

let total_received = if flags.contains(ReceiveFlags::TRUNC) {
// the actual message size
size
} else {
// the number of bytes copied
size.min(total_iov_capacity)
};

// Write back source address if requested.
if want_source {
let addrlen_ptr = MutPtr::<u32>::from_usize(
msg_ptr.as_usize()
+ core::mem::offset_of!(
litebox_common_linux::UserMsgHdr<Platform>,
msg_namelen
),
);
if let Some(src_addr) = source_addr {
let addr_ptr = MutPtr::<u8>::from_usize(msg_name.as_usize());
write_sockaddr_to_user(src_addr, addr_ptr, addrlen_ptr)?;
} else {
// No source address (e.g. connected stream socket) — zero out msg_namelen.
let _ = addrlen_ptr.write_at_offset(0, 0u32);
}
}

// Write back msg_flags with any status flags (e.g. MSG_TRUNC).
let flags_offset =
core::mem::offset_of!(litebox_common_linux::UserMsgHdr<Platform>, msg_flags);
let flags_ptr = MutPtr::<ReceiveFlags>::from_usize(msg_ptr.as_usize() + flags_offset);
let _ = flags_ptr.write_at_offset(0, ret_flags);

Ok(total_received)
}

pub(crate) fn sys_setsockopt(
&self,
sockfd: i32,
Expand Down Expand Up @@ -1842,6 +1967,13 @@ impl<FS: ShimFS> Task<FS> {
flags: 2,
})
}
SocketcallType::Recvmsg => {
parse_socketcall_args!(3 => sys_recvmsg {
sockfd: 0,
msg: [ 1 ],
flags: 2,
})
}
_ => {
log_unsupported!("socketcall type {socketcall_type:?} is not supported");
Err(Errno::EINVAL)
Expand All @@ -1862,6 +1994,7 @@ mod tests {
AddressFamily, ReceiveFlags, SendFlags, SockFlags, SockType, SocketOption,
SocketOptionName, TcpOption, errno::Errno,
};
use zerocopy::FromZeros as _;

use super::SocketAddress;
use crate::{ConstPtr, MutPtr, syscalls::tests::init_platform};
Expand Down Expand Up @@ -1982,7 +2115,7 @@ mod tests {
])
.stdout(std::process::Stdio::piped())
.output(),
"recvfrom" => std::process::Command::new("sh")
"recvfrom" | "recvmsg" => std::process::Command::new("sh")
.args([
"-c",
&alloc::format!(
Expand Down Expand Up @@ -2052,7 +2185,6 @@ mod tests {
},
];
let hdr = {
use zerocopy::FromZeros as _;
let mut h = litebox_common_linux::UserMsgHdr::<crate::Platform>::new_zeroed();
h.msg_iov = ConstPtr::from_usize(iovec.as_ptr() as usize);
h.msg_iovlen = iovec.len();
Expand All @@ -2070,7 +2202,7 @@ mod tests {
let stdout = alloc::string::String::from_utf8_lossy(&output.stdout);
assert_eq!(stdout, alloc::format!("{buf1}{buf2}"));
}
"recvfrom" => {
"recvfrom" | "recvmsg" => {
if is_nonblocking {
epoll_add(task, epfd, client_fd, litebox::event::Events::IN);
let mut events = [litebox_common_linux::EpollEvent { events: 0, data: 0 }; 2];
Expand All @@ -2082,18 +2214,30 @@ mod tests {
}
}
let mut recv_buf = [0u8; 48];
let n = task
.do_recvfrom(
client_fd,
&mut recv_buf,
if test_trunc {
ReceiveFlags::TRUNC
} else {
ReceiveFlags::empty()
},
None,
)
.expect("Failed to receive data");
let flags = if test_trunc {
ReceiveFlags::TRUNC
} else {
ReceiveFlags::empty()
};
let n = match option {
"recvfrom" => task
.do_recvfrom(client_fd, &mut recv_buf, flags, None)
.expect("Failed to receive data"),
"recvmsg" => {
let iovec = [litebox_common_linux::IoVec {
iov_base: MutPtr::from_usize(recv_buf.as_mut_ptr().expose_provenance()),
iov_len: recv_buf.len(),
}];
let mut msg_hdr =
litebox_common_linux::UserMsgHdr::<crate::Platform>::new_zeroed();
msg_hdr.msg_iov = ConstPtr::from_usize(iovec.as_ptr() as usize);
msg_hdr.msg_iovlen = iovec.len();
let msg_ptr = MutPtr::from_usize(&raw mut msg_hdr as usize);
task.sys_recvmsg(i32::try_from(client_fd).unwrap(), msg_ptr, flags)
.expect("failed to recvmsg")
}
_ => unreachable!(),
};
if test_trunc {
assert!(recv_buf.iter().all(|&b| b == 0)); // buf remains unchanged
} else {
Expand All @@ -2109,14 +2253,24 @@ mod tests {
close_socket(task, server);
}

fn test_tcp_socket_with_external_client(
port: u16,
is_nonblocking: bool,
test_trunc: bool,
option: &'static str,
) {
fn test_tcp_socket_with_external_client(port: u16, is_nonblocking: bool, test_trunc: bool) {
let task = init_platform(Some(TUN_DEVICE_NAME));
test_tcp_socket_as_server(&task, TUN_IP_ADDR, port, is_nonblocking, test_trunc, option);
test_tcp_socket_as_server(
&task,
TUN_IP_ADDR,
port,
is_nonblocking,
test_trunc,
"recvfrom",
);
test_tcp_socket_as_server(
&task,
TUN_IP_ADDR,
port,
is_nonblocking,
test_trunc,
"recvmsg",
);
}

fn test_tcp_socket_send(is_nonblocking: bool, test_trunc: bool) {
Expand Down Expand Up @@ -2151,17 +2305,17 @@ mod tests {

#[test]
fn test_tun_blocking_recvfrom_tcp_socket() {
test_tcp_socket_with_external_client(SERVER_PORT, false, false, "recvfrom");
test_tcp_socket_with_external_client(SERVER_PORT, false, false);
}

#[test]
fn test_tun_nonblocking_recvfrom_tcp_socket() {
test_tcp_socket_with_external_client(SERVER_PORT, true, false, "recvfrom");
test_tcp_socket_with_external_client(SERVER_PORT, true, false);
}

#[test]
fn test_tun_blocking_recvfrom_tcp_socket_with_truncation() {
test_tcp_socket_with_external_client(SERVER_PORT, false, true, "recvfrom");
test_tcp_socket_with_external_client(SERVER_PORT, false, true);
}

#[test]
Expand Down
Loading
Loading