diff --git a/Cargo.toml b/Cargo.toml index e99607c..1a813b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,9 +59,13 @@ version = "0.5.0" easy-parallel = "3.1.0" fastrand = "2.0.0" socket2 = "0.6.0" +tracing-subscriber = "0.3" [target.'cfg(unix)'.dev-dependencies] libc = "0.2" [target.'cfg(all(unix, not(target_os="vita")))'.dev-dependencies] signal-hook = "0.3.17" + +[target.'cfg(windows)'.dev-dependencies] +tempfile = "3.7" diff --git a/src/iocp/afd.rs b/src/iocp/afd.rs index f7c7c3f..33b3294 100644 --- a/src/iocp/afd.rs +++ b/src/iocp/afd.rs @@ -25,11 +25,11 @@ use windows_sys::Win32::Networking::WinSock::{ }; use windows_sys::Win32::Storage::FileSystem::{FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE}; use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress}; -use windows_sys::Win32::System::IO::IO_STATUS_BLOCK; +use windows_sys::Win32::System::IO::{IO_STATUS_BLOCK, OVERLAPPED}; #[derive(Default)] #[repr(C)] -pub(super) struct AfdPollInfo { +pub(crate) struct AfdPollInfo { /// The timeout for this poll. timeout: i64, @@ -535,12 +535,27 @@ where } } +// The OVERLAPPED struct is larger than the IO_STATUS_BLOCK struct. +// This way it is possible to use the memory as either/or without +// the risk of reading or writing out-of-bounds. +#[repr(C)] +pub(crate) union PaddedIOStatusBlock { + pub(crate) io_status_block: IO_STATUS_BLOCK, + pub(crate) overlapped: OVERLAPPED, +} + +impl Default for PaddedIOStatusBlock { + fn default() -> Self { + unsafe { core::mem::zeroed() } + } +} + pin_project_lite::pin_project! { /// An I/O status block paired with some auxiliary data. #[repr(C)] - pub(super) struct IoStatusBlock { + pub(crate) struct IoStatusBlock { // The I/O status block. - iosb: UnsafeCell, + padded_io_status_block: UnsafeCell, // Whether or not the block is in use. in_use: AtomicBool, @@ -571,7 +586,7 @@ unsafe impl Sync for IoStatusBlock {} impl From for IoStatusBlock { fn from(data: T) -> Self { Self { - iosb: UnsafeCell::new(unsafe { std::mem::zeroed() }), + padded_io_status_block: UnsafeCell::new(unsafe { std::mem::zeroed() }), in_use: AtomicBool::new(false), data, _marker: PhantomPinned, @@ -580,8 +595,8 @@ impl From for IoStatusBlock { } impl IoStatusBlock { - pub(super) fn iosb(self: Pin<&Self>) -> &UnsafeCell { - self.project_ref().iosb + pub(crate) fn padded_io_status_block(self: Pin<&Self>) -> &UnsafeCell { + self.project_ref().padded_io_status_block } pub(super) fn data(self: Pin<&Self>) -> Pin<&T> { diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index facbe06..a8771cf 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -28,9 +28,12 @@ mod afd; mod port; -use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock}; +use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo}; use port::{IoCompletionPort, OverlappedEntry}; +pub(crate) use afd::IoStatusBlock; +pub(crate) use port::{Completion, CompletionHandle}; + use windows_sys::Win32::Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED}; use windows_sys::Win32::System::Threading::{ RegisterWaitForSingleObject, UnregisterWait, INFINITE, WT_EXECUTELONGFUNCTION, @@ -511,7 +514,7 @@ impl Poller { } /// Push an IOCP packet into the queue. - pub(super) fn post(&self, packet: CompletionPacket) -> io::Result<()> { + pub(super) fn post(&self, packet: crate::os::iocp::CompletionPacket) -> io::Result<()> { self.port.post(0, 0, packet.0) } @@ -709,38 +712,17 @@ impl EventExtra { } } -/// A packet used to wake up the poller with an event. -#[derive(Debug, Clone)] -pub struct CompletionPacket(Packet); - -impl CompletionPacket { - /// Create a new completion packet with a custom event. - pub fn new(event: Event) -> Self { - Self(Arc::pin(IoStatusBlock::from(PacketInner::Custom { event }))) - } - - /// Get the event associated with this packet. - pub fn event(&self) -> &Event { - let data = self.0.as_ref().data().project_ref(); - - match data { - PacketInnerProj::Custom { event } => event, - _ => unreachable!(), - } - } -} - /// The type of our completion packet. /// /// It needs to be pinned, since it contains data that is expected by IOCP not to be moved. -type Packet = Pin>; +pub(crate) type Packet = Pin>; type PacketUnwrapped = IoStatusBlock; pin_project! { /// The inner type of the packet. #[project_ref = PacketInnerProj] #[project = PacketInnerProjMut] - enum PacketInner { + pub(crate) enum PacketInner { // A packet for a socket. Socket { // The AFD packet state. @@ -796,6 +778,16 @@ impl HasAfdInfo for PacketInner { } impl PacketUnwrapped { + /// If this is an event packet, get the event. + pub(crate) fn event(self: Pin<&Self>) -> &Event { + let data = self.data().project_ref(); + + match data { + PacketInnerProj::Custom { event } => event, + _ => unreachable!(), + } + } + /// Set the new events that this socket is waiting on. /// /// Returns `true` if we need to be updated. @@ -995,10 +987,10 @@ impl PacketUnwrapped { unsafe { // SAFETY: The packet is not in transit. - let iosb = &mut *self.as_ref().iosb().get(); + let iosb = &mut *self.as_ref().padded_io_status_block().get(); // Check the status. - match iosb.Anonymous.Status { + match iosb.io_status_block.Anonymous.Status { STATUS_CANCELLED => { // Poll request was cancelled. } @@ -1113,7 +1105,7 @@ impl PacketUnwrapped { /// Per-socket state. #[derive(Debug)] -struct SocketState { +pub(crate) struct SocketState { /// The raw socket handle. socket: RawSocket, @@ -1158,7 +1150,7 @@ enum SocketStatus { /// Per-waitable handle state. #[derive(Debug)] -struct WaitableState { +pub(crate) struct WaitableState { /// The handle that this state is for. handle: RawHandle, diff --git a/src/iocp/port.rs b/src/iocp/port.rs index 6d9b8be..2887f50 100644 --- a/src/iocp/port.rs +++ b/src/iocp/port.rs @@ -27,7 +27,7 @@ use windows_sys::Win32::System::IO::{ /// # Safety /// /// This must be a valid completion block. -pub(super) unsafe trait Completion { +pub(crate) unsafe trait Completion { /// Signal to the completion block that we are about to start an operation. fn try_lock(self: Pin<&Self>) -> bool; @@ -40,7 +40,7 @@ pub(super) unsafe trait Completion { /// # Safety /// /// This must be a valid completion block. -pub(super) unsafe trait CompletionHandle: Deref + Sized { +pub(crate) unsafe trait CompletionHandle: Deref + Sized { /// Type of the completion block. type Completion: Completion; diff --git a/src/os/iocp.rs b/src/os/iocp.rs index 3370118..2251251 100644 --- a/src/os/iocp.rs +++ b/src/os/iocp.rs @@ -1,6 +1,6 @@ //! Functionality that is only available for IOCP-based platforms. -pub use crate::sys::CompletionPacket; +use crate::sys::{Completion, CompletionHandle, IoStatusBlock, Packet, PacketInner}; use super::__private::PollerSealed; use crate::{Event, PollMode, Poller}; @@ -8,6 +8,61 @@ use crate::{Event, PollMode, Poller}; use std::io; use std::os::windows::io::{AsRawHandle, RawHandle}; use std::os::windows::prelude::{AsHandle, BorrowedHandle}; +use std::pin::Pin; +use std::sync::Arc; + +/// A packet used to wake up the poller with an event. +#[derive(Debug, Clone)] +pub struct CompletionPacket(pub(crate) Packet); + +impl CompletionPacket { + /// Create a new completion packet with a custom event. + pub fn new(event: Event) -> Self { + Self(Arc::pin(IoStatusBlock::from(PacketInner::Custom { event }))) + } + + /// Get the event associated with this packet. + pub fn event(&self) -> &Event { + self.0.as_ref().event() + } + + /// Get a pointer to the underlying I/O status block. + /// + /// This pointer can be used as an `OVERLAPPED` block in Windows APIs. Calling this function + /// marks the block as "in use". Trying to call this function again before the operation is + /// indicated as complete by the poller will result in a panic. + pub fn as_overlapped_ptr(&self) -> *mut () { + if !self.0.as_ref().get().try_lock() { + panic!("completion packet is already in use"); + } + // The key point here is to increment the Arc reference count by cloning it. + // Otherwise, the Arc<> will be dropped in the method Poller::wait_deadline + // after it is re-created via from_raw() once the overlapped io has completed. + unsafe { + Arc::into_raw(Pin::into_inner_unchecked(self.0.clone())) as *mut () + } + } + + /// Get the number of transferred bytes after an OVERLAPPED IO has finished. + pub fn transferred_bytes(&self) -> usize { + if !self.0.as_ref().get().try_lock() { + panic!("completion packet is currently in use"); + } + + unsafe { + (*self.0.as_ref().padded_io_status_block().get()).overlapped.InternalHigh + } + } + + /// Cancel the in flight operation. + /// + /// # Safety + /// + /// The packet must be in flight and the operation must be cancelled already. + pub unsafe fn cancel(&mut self) { + self.0.as_ref().get().unlock(); + } +} /// Extension trait for the [`Poller`] type that provides functionality specific to IOCP-based /// platforms. diff --git a/tests/windows_overlapped.rs b/tests/windows_overlapped.rs new file mode 100644 index 0000000..0564f10 --- /dev/null +++ b/tests/windows_overlapped.rs @@ -0,0 +1,180 @@ +//! Take advantage of overlapped I/O on Windows using CompletionPacket. +#![cfg(windows)] + +use polling::os::iocp::CompletionPacket; +use polling::{Event, Events, Poller}; + +use std::io; +use std::os::windows::ffi::OsStrExt; +use std::os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle}; +use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wfs, System::IO as wio}; + +#[test] +fn win32_file_io() { + // Create a poller. + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + // Open a file for writing. + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("test.txt"); + let fname = file_path + .as_os_str() + .encode_wide() + .chain(Some(0)) + .collect::>(); + let file_handle = unsafe { + let raw_handle = wfs::CreateFileW( + fname.as_ptr(), + wf::GENERIC_WRITE | wf::GENERIC_READ, + 0, + std::ptr::null_mut(), + wfs::CREATE_ALWAYS, + wfs::FILE_FLAG_OVERLAPPED, + std::ptr::null_mut(), + ); + + if raw_handle == wf::INVALID_HANDLE_VALUE { + panic!("CreateFileW failed: {}", io::Error::last_os_error()); + } + + OwnedHandle::from_raw_handle(raw_handle as _) + }; + + // Associate this file with the poller. + unsafe { + let poller_handle = poller.as_raw_handle(); + if wio::CreateIoCompletionPort( + file_handle.as_raw_handle() as _, + poller_handle as _, + 1, + 0) == std::ptr::null_mut() + { + panic!( + "CreateIoCompletionPort failed: {}", + io::Error::last_os_error() + ); + } + } + + // Repeatedly write to the pipe. + let input_text = "Now is the time for all good men to come to the aid of their party"; + let write_buffer : &[u8] = input_text.as_bytes(); + let mut write_buffer_cursor = & *write_buffer; + let mut len = input_text.len(); + + let write_packet = CompletionPacket::new(Event::writable(2)); + + while len > 0 { + // Begin to write. + let ptr = write_packet.as_overlapped_ptr().cast(); + unsafe { + if wfs::WriteFile( + file_handle.as_raw_handle() as _, + write_buffer_cursor.as_ptr() as _, + len as _, + std::ptr::null_mut(), + ptr) == 0 && wf::GetLastError() != wf::ERROR_IO_PENDING + { + panic!("WriteFile failed: {}", io::Error::last_os_error()); + } + } + + + // Wait for the overlapped operation to complete. + 'waiter: loop { + events.clear(); + println!("Starting wait..."); + poller.wait(&mut events, None).unwrap(); + println!("Got events"); + + for event in events.iter() { + if event.writable && event.key == 2 { + let bytes_written = write_packet.transferred_bytes(); + write_buffer_cursor = & write_buffer_cursor[bytes_written as usize..]; + len -= bytes_written as usize; + break 'waiter; + } + } + } + } + + + // Close the file and re-open it for reading. + drop(file_handle); + let file_handle = unsafe { + let raw_handle = wfs::CreateFileW( + fname.as_ptr(), + wf::GENERIC_READ | wf::GENERIC_WRITE, + 0, + std::ptr::null_mut(), + wfs::OPEN_EXISTING, + wfs::FILE_FLAG_OVERLAPPED, + std::ptr::null_mut(), + ); + + if raw_handle == wf::INVALID_HANDLE_VALUE { + panic!("CreateFileW failed: {}", io::Error::last_os_error()); + } + + OwnedHandle::from_raw_handle(raw_handle as _) + }; + + // Associate this file with the poller. + unsafe { + let poller_handle = poller.as_raw_handle(); + if wio::CreateIoCompletionPort( + file_handle.as_raw_handle() as _, + poller_handle as _, + 2, + 0) == std::ptr::null_mut() + { + panic!( + "CreateIoCompletionPort failed: {}", + io::Error::last_os_error() + ); + } + } + + // Repeatedly read from the pipe. + let mut buffer = vec![0u8; 1024]; + let mut buffer_cursor = &mut *buffer; + let mut len = 1024; + let mut bytes_received = 0; + + let read_packet = CompletionPacket::new(Event::readable(1)); + while bytes_received < input_text.len() { + // Begin the read. + let ptr = read_packet.as_overlapped_ptr().cast(); + unsafe { + if wfs::ReadFile( + file_handle.as_raw_handle() as _, + buffer_cursor.as_mut_ptr() as _, + len as _, + std::ptr::null_mut(), + ptr) == 0 && wf::GetLastError() != wf::ERROR_IO_PENDING + { + panic!("ReadFile failed: {}", io::Error::last_os_error()); + } + } + + // Wait for the overlapped operation to complete. + 'waiter: loop { + events.clear(); + poller.wait(&mut events, None).unwrap(); + + for event in events.iter() { + if event.readable && event.key == 1 { + let bytes_read = read_packet.transferred_bytes(); + buffer_cursor = &mut buffer_cursor[bytes_read ..]; + len -= bytes_read; + bytes_received += bytes_read; + break 'waiter; + } + } + } + } + + assert_eq!(bytes_received, input_text.len()); + assert_eq!(&buffer[..bytes_received], input_text.as_bytes()); +}