From 293c057d426a368b0b4ff3d51b04795a4caf72e9 Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Mon, 30 Dec 2024 13:11:43 +0100 Subject: [PATCH] feat: concrete stream type --- Cargo.toml | 5 ++- src/stream.rs | 103 +++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 83 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 105ed61..cacafab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mdev" -version = "0.1.1" +version = "0.2.0" edition = "2021" description = "mini-udev workalike" @@ -10,10 +10,11 @@ bytes = "1.9.0" clap = { version = "4.5.23", features = ["derive", "wrap_help"] } fork = "0.2.0" futures-util = "0.3.31" -kobject-uevent = "0.1.1" +kobject-uevent = "0.2.0" mdev-parser = "0.1.1" netlink-sys = { version = "0.8.7", features = ["tokio_socket"] } nix = { version = "0.29.0", features = ["user", "fs"] } +thiserror = "2.0.9" tokio = { version = "1.42.0", features = [ "macros", "rt-multi-thread", diff --git a/src/stream.rs b/src/stream.rs index 0bb9453..c27c997 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,38 +1,95 @@ -use std::process; +use std::{ + future::Future, + io, + pin::Pin, + process, + task::{ready, Context, Poll}, +}; -use anyhow::anyhow; -use futures_util::stream::{unfold, Stream}; +use futures_util::{FutureExt, Stream}; use kobject_uevent::UEvent; use netlink_sys::{ protocols::NETLINK_KOBJECT_UEVENT, AsyncSocket, AsyncSocketExt, SocketAddr, TokioSocket, }; +type FutureOutput = (TokioSocket, Result, io::Error>); + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Socket open error: {0}")] + Open(io::Error), + #[error("Socket bind error: {0}")] + Bind(io::Error), + #[error("Socket receive error: {0}")] + Receive(io::Error), + #[error(transparent)] + NetlinkPacket(kobject_uevent::Error), +} + /// creates a new stream of UEvents -pub fn uevents() -> anyhow::Result>> { - let mut socket = TokioSocket::new(NETLINK_KOBJECT_UEVENT) - .map_err(|e| anyhow!("Socket open error: {}", e))?; +pub fn uevents() -> Result>, Error> { + let mut socket = TokioSocket::new(NETLINK_KOBJECT_UEVENT).map_err(Error::Open)?; let sa = SocketAddr::new(process::id(), 1); - socket - .socket_mut() - .bind(&sa) - .map_err(|e| anyhow!("Socket bind error: {}", e))?; - - Ok(unfold( - (socket, bytes::BytesMut::with_capacity(1024 * 8)), - |(mut socket, mut buf)| async move { - buf.clear(); - match socket.recv_from(&mut buf).await { - Ok(_addr) => { + socket.socket_mut().bind(&sa).map_err(Error::Bind)?; + + Ok(UEventsStream::new(socket)) +} + +enum UEventsStream { + Socket(TokioSocket), + Future(Pin>>), + None, +} + +impl UEventsStream { + pub fn new(socket: TokioSocket) -> Self { + Self::Socket(socket) + } + + fn take_socket(&mut self) -> Option { + if matches!(self, Self::Socket(_)) { + let Self::Socket(socket) = std::mem::replace(self, Self::None) else { + unreachable!(); + }; + Some(socket) + } else { + None + } + } +} + +impl Stream for UEventsStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let Some(mut socket) = this.take_socket() { + *this = Self::Future(Box::pin(async move { + let res = socket.recv_from_full().await.map(|(buf, _)| buf); + (socket, res) + })); + } + + if let Self::Future(fut) = this { + let (socket, res) = ready!(fut.poll_unpin(cx)); + *this = Self::Socket(socket); + match res { + Ok(buf) => { if buf.is_empty() { - return None; + return Poll::Ready(None); + } else { + return Poll::Ready(Some( + UEvent::from_netlink_packet(&buf).map_err(Error::NetlinkPacket), + )); } } Err(e) => { - return Some((Err(anyhow!("Socket receive error: {}", e)), (socket, buf))); + return Poll::Ready(Some(Err(Error::Receive(e)))); } - }; + } + } - Some((UEvent::from_netlink_packet(&buf), (socket, buf))) - }, - )) + Poll::Pending + } }