diff --git a/Cargo.toml b/Cargo.toml index ec69b0e9..5db3590f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,21 +1,34 @@ [package] -name = "socket2" -version = "0.4.0-dev" -authors = ["Alex Crichton "] -license = "MIT/Apache-2.0" -readme = "README.md" -repository = "https://github.com/alexcrichton/socket2-rs" -homepage = "https://github.com/alexcrichton/socket2-rs" +name = "socket2" +version = "0.4.0-dev" +authors = ["Alex Crichton "] +license = "MIT/Apache-2.0" +readme = "README.md" +repository = "https://github.com/rust-lang/socket2-rs" +homepage = "https://github.com/rust-lang/socket2-rs" +documentation = "https://docs.rs/socket2" description = """ Utilities for handling networking sockets with a maximal amount of configuration possible intended. """ -edition = "2018" +keywords = ["io", "socket", "network"] +categories = ["api-bindings", "network-programming", "web-programming"] +edition = "2018" +include = [ + "Cargo.toml", + "LICENSE-APACHE", + "LICENSE-MIT", + "README.md", + "src/**/*.rs", +] [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] +[package.metadata.playground] +features = ["all"] + [target."cfg(windows)".dependencies.winapi] version = "0.3.3" features = ["handleapi", "ws2def", "ws2ipdef", "ws2tcpip", "minwindef"] diff --git a/src/socket.rs b/src/socket.rs index b31392a3..986730b1 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -43,39 +43,6 @@ use crate::{Domain, Protocol, SockAddr, Type}; /// /// # Examples /// -/// Creating a new socket setting all advisable flags. -/// -#[cfg_attr(feature = "all", doc = "```")] // Protocol::cloexec requires the `all` feature. -#[cfg_attr(not(feature = "all"), doc = "```ignore")] -/// # fn main() -> std::io::Result<()> { -/// use socket2::{Protocol, Domain, Type, Socket}; -/// -/// let domain = Domain::IPV4; -/// let ty = Type::STREAM; -/// let protocol = Protocol::TCP; -/// -/// // On platforms that support it set `SOCK_CLOEXEC`. -/// #[cfg(any(target_os = "android", target_os = "dragonfly", target_os = "freebsd", target_os = "linux", target_os = "netbsd", target_os = "openbsd"))] -/// let ty = ty.cloexec(); -/// -/// let socket = Socket::new(domain, ty, Some(protocol))?; -/// -/// // On platforms that don't support `SOCK_CLOEXEC`, use `FD_CLOEXEC`. -/// #[cfg(all(not(windows), not(any(target_os = "android", target_os = "dragonfly", target_os = "freebsd", target_os = "linux", target_os = "netbsd", target_os = "openbsd"))))] -/// socket.set_cloexec()?; -/// -/// // On macOS and iOS set `NOSIGPIPE`. -/// #[cfg(target_vendor = "apple")] -/// socket.set_nosigpipe()?; -/// -/// // On windows set `HANDLE_FLAG_INHERIT`. -/// #[cfg(windows)] -/// socket.set_no_inherit()?; -/// # drop(socket); -/// # Ok(()) -/// # } -/// ``` -/// /// ```no_run /// # fn main() -> std::io::Result<()> { /// use std::net::{SocketAddr, TcpListener}; @@ -101,6 +68,55 @@ pub struct Socket { } impl Socket { + /// Creates a new socket ready to be configured. + /// + /// This function corresponds to `socket(2)` on Unix and `WSASocketW` on + /// Windows and creates a new socket. Unlike `Socket::new_raw` this sets the + /// most commonly used flags in the fastest possible way. + /// + /// On Unix this sets the `CLOEXEC` flag. Furthermore on macOS and iOS + /// `NOSIGPIPE` is set. + /// + /// On Windows the `HANDLE_FLAG_INHERIT` is set to zero. + pub fn new(domain: Domain, ty: Type, protocol: Option) -> io::Result { + // On platforms that support it set `SOCK_CLOEXEC`. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + ))] + let ty = ty.cloexec(); + + // On windows set `WSA_FLAG_NO_HANDLE_INHERIT`. + #[cfg(windows)] + let ty = ty.no_inherit(); + + let socket = Socket::new_raw(domain, ty, protocol)?; + + // On platforms that don't support `SOCK_CLOEXEC`, use `FD_CLOEXEC`. + #[cfg(all( + not(windows), + not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd" + )) + ))] + socket.set_cloexec()?; + + // On macOS and iOS set `NOSIGPIPE`. + #[cfg(target_vendor = "apple")] + socket._set_nosigpipe()?; + + Ok(socket) + } + /// Creates a new socket ready to be configured. /// /// This function corresponds to `socket(2)` on Unix and `WSASocketW` on @@ -123,7 +139,7 @@ impl Socket { /// /// See the `Socket` documentation for a full example of setting all the /// above mentioned flags. - pub fn new(domain: Domain, ty: Type, protocol: Option) -> io::Result { + pub fn new_raw(domain: Domain, ty: Type, protocol: Option) -> io::Result { let protocol = protocol.map(|p| p.0).unwrap_or(0); sys::socket(domain.0, ty.0, protocol).map(|inner| Socket { inner }) } @@ -198,26 +214,44 @@ impl Socket { sys::accept(self.inner).map(|(inner, addr)| (Socket { inner }, addr)) } - /// Returns the socket address of the local half of this TCP connection. + /// Returns the socket address of the local half of this socket. + /// + /// # Notes + /// + /// Depending on the OS this may return an error if the socket is not + /// [bound]. + /// + /// [bound]: Socket::bind pub fn local_addr(&self) -> io::Result { - self.inner().local_addr() + sys::getsockname(self.inner) } - /// Returns the socket address of the remote peer of this TCP connection. + /// Returns the socket address of the remote peer of this socket. + /// + /// # Notes + /// + /// This returns an error if the socket is not [`connect`ed]. + /// + /// [`connect`ed]: Socket::connect pub fn peer_addr(&self) -> io::Result { - self.inner().peer_addr() + sys::getpeername(self.inner) } /// Creates a new independently owned handle to the underlying socket. /// - /// The returned `TcpStream` is a reference to the same stream that this - /// object references. Both handles will read and write the same stream of - /// data, and options set on one stream will be propagated to the other - /// stream. + /// # Notes + /// + /// On Unix this uses `F_DUPFD_CLOEXEC` and thus sets the `FD_CLOEXEC` on + /// the returned socket. + /// + /// On Windows this uses `WSA_FLAG_NO_HANDLE_INHERIT` setting inheriting to + /// false. + /// + /// On Windows this can **not** be used function cannot be used on a + /// QOS-enabled socket, see + /// https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsaduplicatesocketw. pub fn try_clone(&self) -> io::Result { - self.inner() - .try_clone() - .map(|s| Socket { inner: s.inner() }) + sys::try_clone(self.inner).map(|inner| Socket { inner }) } /// Get the value of the `SO_ERROR` option on this socket. @@ -226,15 +260,19 @@ impl Socket { /// the field in the process. This can be useful for checking errors between /// calls. pub fn take_error(&self) -> io::Result> { - self.inner().take_error() + sys::take_error(self.inner) } /// Moves this TCP stream into or out of nonblocking mode. /// - /// On Unix this corresponds to calling fcntl, and on Windows this - /// corresponds to calling ioctlsocket. + /// # Notes + /// + /// On Unix this corresponds to calling `fcntl` (un)setting `O_NONBLOCK`. + /// + /// On Windows this corresponds to calling `ioctlsocket` (un)setting + /// `FIONBIO`. pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.inner().set_nonblocking(nonblocking) + sys::set_nonblocking(self.inner, nonblocking) } /// Shuts down the read, write, or both halves of this connection. @@ -242,7 +280,7 @@ impl Socket { /// This function will cause all pending and future I/O on the specified /// portions to return immediately with an appropriate value. pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - self.inner().shutdown(how) + sys::shutdown(self.inner, how) } /// Receives data on the socket from the remote address to which it is diff --git a/src/sys/unix.rs b/src/sys/unix.rs index 0c984fc7..58cc48fc 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -9,6 +9,7 @@ #[cfg(not(target_os = "redox"))] use std::io::{IoSlice, IoSliceMut}; use std::io::{Read, Write}; +use std::mem::{self, size_of_val, MaybeUninit}; use std::net::Shutdown; use std::net::{self, Ipv4Addr, Ipv6Addr}; #[cfg(feature = "all")] @@ -19,9 +20,8 @@ use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; use std::path::Path; #[cfg(feature = "all")] use std::ptr; -use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use std::{cmp, fmt, io, mem}; +use std::{cmp, fmt, io}; use libc::{self, c_void, in6_addr, in_addr, ssize_t}; @@ -324,13 +324,57 @@ pub(crate) fn listen(fd: SysSocket, backlog: i32) -> io::Result<()> { pub(crate) fn accept(fd: SysSocket) -> io::Result<(SysSocket, SockAddr)> { // Safety: zeroed `sockaddr_storage` is valid. let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; - let mut len = mem::size_of_val(&storage) as socklen_t; + let mut len = size_of_val(&storage) as socklen_t; syscall!(accept(fd, &mut storage as *mut _ as *mut _, &mut len)).map(|fd| { let addr = unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }; (fd, addr) }) } +pub(crate) fn getsockname(fd: SysSocket) -> io::Result { + let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let mut len = size_of_val(&storage) as libc::socklen_t; + syscall!(getsockname(fd, &mut storage as *mut _ as *mut _, &mut len,)) + .map(|_| unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }) +} + +pub(crate) fn getpeername(fd: SysSocket) -> io::Result { + let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; + let mut len = size_of_val(&storage) as libc::socklen_t; + syscall!(getpeername(fd, &mut storage as *mut _ as *mut _, &mut len,)) + .map(|_| unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }) +} + +pub(crate) fn try_clone(fd: SysSocket) -> io::Result { + syscall!(fcntl(fd, libc::F_DUPFD_CLOEXEC, 0)) +} + +pub(crate) fn take_error(fd: SysSocket) -> io::Result> { + match unsafe { getsockopt::(fd, libc::SOL_SOCKET, libc::SO_ERROR) } { + Ok(0) => Ok(None), + Ok(errno) => Ok(Some(io::Error::from_raw_os_error(errno))), + Err(err) => Err(err), + } +} + +pub(crate) fn set_nonblocking(fd: SysSocket, nonblocking: bool) -> io::Result<()> { + if nonblocking { + fcntl_add(fd, libc::O_NONBLOCK) + } else { + fcntl_remove(fd, libc::O_NONBLOCK) + } +} + +pub(crate) fn shutdown(fd: SysSocket, how: Shutdown) -> io::Result<()> { + let how = match how { + Shutdown::Write => libc::SHUT_WR, + Shutdown::Read => libc::SHUT_RD, + Shutdown::Both => libc::SHUT_RDWR, + }; + syscall!(shutdown(fd, how)).map(|_| ()) +} + +/// Unix only API. impl crate::Socket { /// Accept a new incoming connection from this listener. /// @@ -381,13 +425,25 @@ impl crate::Socket { fcntl_add(self.inner, libc::FD_CLOEXEC) } - /// Sets `SO_NOSIGPIPE` to one. - #[cfg(target_vendor = "apple")] + /// Sets `SO_NOSIGPIPE` on the socket. + /// + /// # Notes + /// + /// Only supported on Apple platforms (`target_vendor = "apple"`). + #[cfg(all(feature = "all", target_vendor = "apple"))] pub fn set_nosigpipe(&self) -> io::Result<()> { + self._set_nosigpipe() + } + + // Because `set_nosigpipe` is behind the `all` feature flag we need a + // private version for `Socket::new`, which is always enabled. + #[cfg(target_vendor = "apple")] + pub(crate) fn _set_nosigpipe(&self) -> io::Result<()> { unsafe { setsockopt(self.inner, libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1i32) } } } +/// Add `flag` to the current set flags of `F_GETFD`. fn fcntl_add(fd: SysSocket, flag: c_int) -> io::Result<()> { let previous = syscall!(fcntl(fd, libc::F_GETFD))?; let new = previous | flag; @@ -399,6 +455,37 @@ fn fcntl_add(fd: SysSocket, flag: c_int) -> io::Result<()> { } } +/// Remove `flag` to the current set flags of `F_GETFD`. +fn fcntl_remove(fd: SysSocket, flag: c_int) -> io::Result<()> { + let previous = syscall!(fcntl(fd, libc::F_GETFD))?; + let new = previous & !flag; + if new != previous { + syscall!(fcntl(fd, libc::F_SETFD, new)).map(|_| ()) + } else { + // Flag was already set. + Ok(()) + } +} + +/// Caller must ensure `T` is the correct type for `opt` and `val`. +unsafe fn getsockopt(fd: SysSocket, opt: c_int, val: c_int) -> io::Result { + let mut payload: MaybeUninit = MaybeUninit::uninit(); + let mut len = mem::size_of::() as libc::socklen_t; + syscall!(getsockopt( + fd, + opt, + val, + payload.as_mut_ptr().cast(), + &mut len, + )) + .map(|_| { + debug_assert_eq!(len as usize, mem::size_of::()); + // Safety: `getsockopt` initialised `payload` for us. + payload.assume_init() + }) +} + +/// Caller must ensure `T` is the correct type for `opt` and `val`. #[cfg(target_vendor = "apple")] unsafe fn setsockopt(fd: SysSocket, opt: c_int, val: c_int, payload: T) -> io::Result<()> where @@ -411,8 +498,8 @@ where val, payload, mem::size_of::() as libc::socklen_t, - ))?; - Ok(()) + )) + .map(|_| ()) } #[repr(transparent)] // Required during rewriting. @@ -421,91 +508,6 @@ pub struct Socket { } impl Socket { - pub fn local_addr(&self) -> io::Result { - let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; - let mut len = mem::size_of_val(&storage) as libc::socklen_t; - syscall!(getsockname( - self.fd, - &mut storage as *mut _ as *mut _, - &mut len, - ))?; - Ok(unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }) - } - - pub fn peer_addr(&self) -> io::Result { - let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() }; - let mut len = mem::size_of_val(&storage) as libc::socklen_t; - syscall!(getpeername( - self.fd, - &mut storage as *mut _ as *mut _, - &mut len, - ))?; - Ok(unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }) - } - - pub fn try_clone(&self) -> io::Result { - // implementation lifted from libstd - #[cfg(any(target_os = "android", target_os = "haiku"))] - use libc::F_DUPFD as F_DUPFD_CLOEXEC; - #[cfg(not(any(target_os = "android", target_os = "haiku")))] - use libc::F_DUPFD_CLOEXEC; - - static CLOEXEC_FAILED: AtomicBool = AtomicBool::new(false); - if !CLOEXEC_FAILED.load(Ordering::Relaxed) { - match syscall!(fcntl(self.fd, F_DUPFD_CLOEXEC, 0)) { - Ok(fd) => { - let fd = unsafe { Socket::from_raw_fd(fd) }; - if cfg!(target_os = "linux") { - set_cloexec(fd.as_raw_fd())?; - } - return Ok(fd); - } - Err(ref e) if e.raw_os_error() == Some(libc::EINVAL) => { - CLOEXEC_FAILED.store(true, Ordering::Relaxed); - } - Err(e) => return Err(e), - } - } - let fd = syscall!(fcntl(self.fd, libc::F_DUPFD, 0))?; - let fd = unsafe { Socket::from_raw_fd(fd) }; - set_cloexec(fd.as_raw_fd())?; - Ok(fd) - } - - pub fn take_error(&self) -> io::Result> { - unsafe { - let raw: c_int = self.getsockopt(libc::SOL_SOCKET, libc::SO_ERROR)?; - if raw == 0 { - Ok(None) - } else { - Ok(Some(io::Error::from_raw_os_error(raw as i32))) - } - } - } - - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - let previous = syscall!(fcntl(self.fd, libc::F_GETFL))?; - let new = if nonblocking { - previous | libc::O_NONBLOCK - } else { - previous & !libc::O_NONBLOCK - }; - if new != previous { - syscall!(fcntl(self.fd, libc::F_SETFL, new))?; - } - Ok(()) - } - - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - let how = match how { - Shutdown::Write => libc::SHUT_WR, - Shutdown::Read => libc::SHUT_RD, - Shutdown::Both => libc::SHUT_RDWR, - }; - syscall!(shutdown(self.fd, how))?; - Ok(()) - } - pub fn recv(&self, buf: &mut [u8], flags: c_int) -> io::Result { let n = syscall!(recv( self.fd, @@ -1047,10 +1049,10 @@ impl fmt::Debug for Socket { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut f = f.debug_struct("Socket"); f.field("fd", &self.fd); - if let Ok(addr) = self.local_addr() { + if let Ok(addr) = getsockname(self.fd) { f.field("local_addr", &addr); } - if let Ok(addr) = self.peer_addr() { + if let Ok(addr) = getpeername(self.fd) { f.field("peer_addr", &addr); } f.finish() @@ -1193,15 +1195,6 @@ fn max_len() -> usize { } } -fn set_cloexec(fd: c_int) -> io::Result<()> { - let previous = syscall!(fcntl(fd, libc::F_GETFD))?; - let new = previous | libc::FD_CLOEXEC; - if new != previous { - syscall!(fcntl(fd, libc::F_SETFD, new))?; - } - Ok(()) -} - fn dur2timeval(dur: Option) -> io::Result { match dur { Some(dur) => { diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 207780cb..d9d13c10 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -10,7 +10,7 @@ use std::cmp; use std::fmt; use std::io; use std::io::{IoSlice, IoSliceMut, Read, Write}; -use std::mem; +use std::mem::{self, size_of, size_of_val, MaybeUninit}; use std::net::Shutdown; use std::net::{self, Ipv4Addr, Ipv6Addr}; use std::os::windows::prelude::*; @@ -18,27 +18,26 @@ use std::ptr; use std::sync::Once; use std::time::Duration; -use winapi::ctypes::{c_char, c_ulong}; +use winapi::ctypes::{c_char, c_long, c_ulong}; use winapi::shared::in6addr::*; use winapi::shared::inaddr::*; use winapi::shared::minwindef::DWORD; +#[cfg(feature = "all")] use winapi::shared::ntdef::HANDLE; use winapi::shared::ws2def::{self, *}; use winapi::shared::ws2ipdef::*; +#[cfg(feature = "all")] use winapi::um::handleapi::SetHandleInformation; use winapi::um::processthreadsapi::GetCurrentProcessId; +#[cfg(feature = "all")] +use winapi::um::winbase; use winapi::um::winbase::INFINITE; -use winapi::um::winsock2 as sock; +use winapi::um::winsock2::{self as sock, u_long, SD_BOTH, SD_RECEIVE, SD_SEND}; -use crate::{RecvFlags, SockAddr}; +use crate::{RecvFlags, SockAddr, Type}; -const HANDLE_FLAG_INHERIT: DWORD = 0x00000001; const MSG_PEEK: c_int = 0x2; -const SD_BOTH: c_int = 2; -const SD_RECEIVE: c_int = 0; -const SD_SEND: c_int = 1; const SIO_KEEPALIVE_VALS: DWORD = 0x98000004; -const WSA_FLAG_OVERLAPPED: DWORD = 0x01; pub use winapi::ctypes::c_int; @@ -70,6 +69,7 @@ pub(crate) use winapi::um::ws2tcpip::socklen_t; /// Helper macro to execute a system call that returns an `io::Result`. macro_rules! syscall { ($fn: ident ( $($arg: expr),* $(,)* ), $err_test: path, $err_value: expr) => {{ + #[allow(unused_unsafe)] let res = unsafe { sock::$fn($($arg, )*) }; if $err_test(&res, &$err_value) { Err(io::Error::last_os_error()) @@ -87,6 +87,19 @@ impl_debug!( ws2def::AF_UNSPEC, // = 0. ); +/// Windows only API. +impl Type { + /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation. + /// Trying to mimic `Type::cloexec` on windows. + const NO_INHERIT: c_int = 1 << (size_of::()); + + /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket. + #[cfg(feature = "all")] + pub const fn no_inherit(self) -> Type { + Type(self.0 | Type::NO_INHERIT) + } +} + impl_debug!( crate::Type, ws2def::SOCK_STREAM, @@ -137,9 +150,17 @@ fn last_error() -> io::Error { // TODO: rename to `Socket` once the struct `Socket` is no longer used. pub(crate) type SysSocket = sock::SOCKET; -pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result { +pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result { init(); + // Check if we set our custom flag. + let flags = if ty & Type::NO_INHERIT != 0 { + ty = ty & !Type::NO_INHERIT; + sock::WSA_FLAG_NO_HANDLE_INHERIT + } else { + 0 + }; + syscall!( WSASocketW( family, @@ -147,7 +168,7 @@ pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result io::Result<()> { pub(crate) fn accept(socket: SysSocket) -> io::Result<(SysSocket, SockAddr)> { // Safety: zeroed `SOCKADDR_STORAGE` is valid. let mut storage: SOCKADDR_STORAGE = unsafe { mem::zeroed() }; - let mut len = mem::size_of_val(&storage) as c_int; + let mut len = size_of_val(&storage) as c_int; syscall!( accept(socket, &mut storage as *mut _ as *mut _, &mut len), PartialEq::eq, @@ -181,12 +202,113 @@ pub(crate) fn accept(socket: SysSocket) -> io::Result<(SysSocket, SockAddr)> { }) } +pub(crate) fn getsockname(socket: SysSocket) -> io::Result { + // Safety: zeroed `SOCKADDR_STORAGE` is valid. + let mut storage: SOCKADDR_STORAGE = unsafe { mem::zeroed() }; + let mut len = size_of_val(&storage) as c_int; + syscall!( + getsockname(socket, &mut storage as *mut _ as *mut _, &mut len), + PartialEq::eq, + sock::SOCKET_ERROR + ) + .map(|_| unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }) +} + +pub(crate) fn getpeername(socket: SysSocket) -> io::Result { + // Safety: zeroed `SOCKADDR_STORAGE` is valid. + let mut storage: SOCKADDR_STORAGE = unsafe { mem::zeroed() }; + let mut len = size_of_val(&storage) as c_int; + syscall!( + getpeername(socket, &mut storage as *mut _ as *mut _, &mut len), + PartialEq::eq, + sock::SOCKET_ERROR + ) + .map(|_| unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, len) }) +} + +pub(crate) fn try_clone(socket: SysSocket) -> io::Result { + let mut info: MaybeUninit = MaybeUninit::uninit(); + syscall!( + WSADuplicateSocketW(socket, GetCurrentProcessId(), info.as_mut_ptr()), + PartialEq::eq, + sock::SOCKET_ERROR + )?; + // Safety: `WSADuplicateSocketW` intialised `info` for us. + let mut info = unsafe { info.assume_init() }; + + syscall!( + WSASocketW( + info.iAddressFamily, + info.iSocketType, + info.iProtocol, + &mut info, + 0, + sock::WSA_FLAG_OVERLAPPED | sock::WSA_FLAG_NO_HANDLE_INHERIT, + ), + PartialEq::eq, + sock::INVALID_SOCKET + ) +} + +pub(crate) fn take_error(socket: SysSocket) -> io::Result> { + match unsafe { getsockopt::(socket, SOL_SOCKET, SO_ERROR) } { + Ok(0) => Ok(None), + Ok(errno) => Ok(Some(io::Error::from_raw_os_error(errno))), + Err(err) => Err(err), + } +} + +pub(crate) fn set_nonblocking(socket: SysSocket, nonblocking: bool) -> io::Result<()> { + let mut nonblocking = nonblocking as u_long; + ioctlsocket(socket, sock::FIONBIO, &mut nonblocking) +} + +pub(crate) fn shutdown(socket: SysSocket, how: Shutdown) -> io::Result<()> { + let how = match how { + Shutdown::Write => SD_SEND, + Shutdown::Read => SD_RECEIVE, + Shutdown::Both => SD_BOTH, + }; + syscall!(shutdown(socket, how), PartialEq::eq, sock::SOCKET_ERROR).map(|_| ()) +} + +/// Caller must ensure `T` is the correct type for `opt` and `val`. +unsafe fn getsockopt(socket: SysSocket, opt: c_int, val: c_int) -> io::Result { + let mut payload: MaybeUninit = MaybeUninit::uninit(); + let mut len = mem::size_of::() as c_int; + syscall!( + getsockopt(socket, opt, val, payload.as_mut_ptr().cast(), &mut len,), + PartialEq::eq, + sock::SOCKET_ERROR + ) + .map(|_| { + debug_assert_eq!(len as usize, mem::size_of::()); + // Safety: `getsockopt` initialised `payload` for us. + payload.assume_init() + }) +} + +fn ioctlsocket(socket: SysSocket, cmd: c_long, payload: &mut u_long) -> io::Result<()> { + syscall!( + ioctlsocket(socket, cmd, payload), + PartialEq::eq, + sock::SOCKET_ERROR + ) + .map(|_| ()) +} + +/// Windows only API. impl crate::Socket { /// Sets `HANDLE_FLAG_INHERIT` to zero using `SetHandleInformation`. + #[cfg(feature = "all")] pub fn set_no_inherit(&self) -> io::Result<()> { - let r = unsafe { SetHandleInformation(self.inner as HANDLE, HANDLE_FLAG_INHERIT, 0) }; - if r == 0 { - Err(last_error()) + // NOTE: can't use `syscall!` because it expects the function in the + // `sock::` path. + let res = + unsafe { SetHandleInformation(self.inner as HANDLE, winbase::HANDLE_FLAG_INHERIT, 0) }; + if res == 0 { + // Zero means error. + Err(io::Error::last_os_error()) } else { Ok(()) } @@ -199,94 +321,6 @@ pub struct Socket { } impl Socket { - pub fn local_addr(&self) -> io::Result { - unsafe { - let mut storage: SOCKADDR_STORAGE = mem::zeroed(); - let mut len = mem::size_of_val(&storage) as c_int; - if sock::getsockname(self.socket, &mut storage as *mut _ as *mut _, &mut len) != 0 { - return Err(last_error()); - } - Ok(SockAddr::from_raw_parts( - &storage as *const _ as *const _, - len, - )) - } - } - - pub fn peer_addr(&self) -> io::Result { - unsafe { - let mut storage: SOCKADDR_STORAGE = mem::zeroed(); - let mut len = mem::size_of_val(&storage) as c_int; - if sock::getpeername(self.socket, &mut storage as *mut _ as *mut _, &mut len) != 0 { - return Err(last_error()); - } - Ok(SockAddr::from_raw_parts( - &storage as *const _ as *const _, - len, - )) - } - } - - pub fn try_clone(&self) -> io::Result { - unsafe { - let mut info: sock::WSAPROTOCOL_INFOW = mem::zeroed(); - let r = sock::WSADuplicateSocketW(self.socket, GetCurrentProcessId(), &mut info); - if r != 0 { - return Err(io::Error::last_os_error()); - } - let socket = sock::WSASocketW( - info.iAddressFamily, - info.iSocketType, - info.iProtocol, - &mut info, - 0, - WSA_FLAG_OVERLAPPED, - ); - let socket = match socket { - sock::INVALID_SOCKET => return Err(last_error()), - n => Socket::from_raw_socket(n as RawSocket), - }; - socket.set_no_inherit()?; - Ok(socket) - } - } - - pub fn take_error(&self) -> io::Result> { - unsafe { - let raw: c_int = self.getsockopt(SOL_SOCKET, SO_ERROR)?; - if raw == 0 { - Ok(None) - } else { - Ok(Some(io::Error::from_raw_os_error(raw as i32))) - } - } - } - - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - unsafe { - let mut nonblocking = nonblocking as c_ulong; - let r = sock::ioctlsocket(self.socket, sock::FIONBIO as c_int, &mut nonblocking); - if r == 0 { - Ok(()) - } else { - Err(io::Error::last_os_error()) - } - } - } - - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { - let how = match how { - Shutdown::Write => SD_SEND, - Shutdown::Read => SD_RECEIVE, - Shutdown::Both => SD_BOTH, - }; - if unsafe { sock::shutdown(self.socket, how) == 0 } { - Ok(()) - } else { - Err(last_error()) - } - } - pub fn recv(&self, buf: &mut [u8], flags: c_int) -> io::Result { unsafe { let n = { @@ -838,17 +872,6 @@ impl Socket { } } - fn set_no_inherit(&self) -> io::Result<()> { - unsafe { - let r = SetHandleInformation(self.socket as HANDLE, HANDLE_FLAG_INHERIT, 0); - if r == 0 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - } - } - } - pub fn inner(self) -> SysSocket { self.socket } @@ -894,10 +917,10 @@ impl fmt::Debug for Socket { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut f = f.debug_struct("Socket"); f.field("socket", &self.socket); - if let Ok(addr) = self.local_addr() { + if let Ok(addr) = getsockname(self.socket) { f.field("local_addr", &addr); } - if let Ok(addr) = self.peer_addr() { + if let Ok(addr) = getpeername(self.socket) { f.field("peer_addr", &addr); } f.finish()