Mercurial > crates > systemd-socket
diff src/lib.rs @ 28:cfef4593e207
Run `cargo fmt`.
author | Paul Fisher <paul@pfish.zone> |
---|---|
date | Sat, 19 Apr 2025 01:33:50 -0400 |
parents | 0feab4f4c2ce |
children | efc69e99db70 |
line wrap: on
line diff
--- a/src/lib.rs Fri Feb 28 23:15:59 2025 +0100 +++ b/src/lib.rs Sat Apr 19 01:33:50 2025 -0400 @@ -1,10 +1,10 @@ //! A convenience crate for optionally supporting systemd socket activation. -//! +//! //! ## About //! //! **Important:** because of various reasons it is recommended to call the [`init`] function at //! the start of your program! -//! +//! //! The goal of this crate is to make socket activation with systemd in your project trivial. //! It provides a replacement for `std::net::SocketAddr` that allows parsing the bind address from string just like the one from `std` //! but on top of that also allows `systemd://socket_name` format that tells it to use systemd activation with given socket name. @@ -19,14 +19,14 @@ //! //! Further, the crate also provides methods for binding `tokio` 1.0, 0.2, 0.3, and `async_std` sockets if the appropriate features are //! activated. -//! +//! //! ## Example -//! +//! //! ```no_run //! use systemd_socket::SocketAddr; //! use std::convert::TryFrom; //! use std::io::Write; -//! +//! //! systemd_socket::init().expect("Failed to initialize systemd sockets"); //! let mut args = std::env::args_os(); //! let program_name = args.next().expect("unknown program name"); @@ -76,27 +76,26 @@ //! That is currently Rust 1.48.0. (Debian 11 - Bullseye) #![cfg_attr(docsrs, feature(doc_auto_cfg))] - #![deny(missing_docs)] pub mod error; mod resolv_addr; -use std::convert::{TryFrom, TryInto}; -use std::fmt; -use std::ffi::{OsStr, OsString}; use crate::error::*; use crate::resolv_addr::ResolvAddr; +use std::convert::{TryFrom, TryInto}; +use std::ffi::{OsStr, OsString}; +use std::fmt; #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] use std::convert::Infallible as Never; #[cfg(all(target_os = "linux", feature = "enable_systemd"))] pub(crate) mod systemd_sockets { + use libsystemd::activation::FileDescriptor; + use libsystemd::errors::SdError as LibSystemdError; use std::fmt; use std::sync::Mutex; - use libsystemd::activation::FileDescriptor; - use libsystemd::errors::SdError as LibSystemdError; #[derive(Debug)] pub(crate) struct Error(&'static Mutex<InitError>); @@ -122,14 +121,17 @@ impl std::error::Error for Error {} pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> { - SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop) + SYSTEMD_SOCKETS + .get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)) + .map(drop) } pub(crate) fn take(name: &str) -> Result<Option<StoredSocket>, Error> { - let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new)); + let sockets = SYSTEMD_SOCKETS + .get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new)); match sockets { Ok(sockets) => Ok(sockets.take(name)), - Err(error) => Err(Error(error)) + Err(error) => Err(Error(error)), } } @@ -153,7 +155,9 @@ match self { Self::OpenStatus(_) => write!(f, "failed to open /proc/self/status"), Self::ReadStatus(_) => write!(f, "failed to read /proc/self/status"), - Self::ThreadCountNotFound => write!(f, "/proc/self/status doesn't contain Threads entry"), + Self::ThreadCountNotFound => { + write!(f, "/proc/self/status doesn't contain Threads entry") + } Self::MultipleThreads => write!(f, "there is more than one thread running"), // We have nothing to say about the error, let's flatten it Self::LibSystemd(error) => fmt::Display::fmt(error, f), @@ -183,7 +187,7 @@ fn try_from(value: FileDescriptor) -> Result<Self, Self::Error> { use libsystemd::activation::IsType; - use std::os::unix::io::{FromRawFd, IntoRawFd, AsRawFd}; + use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; fn set_cloexec(fd: std::os::unix::io::RawFd) { // SAFETY: The function is a harmless syscall @@ -222,17 +226,25 @@ use std::convert::TryFrom; if explicit { - if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() { + if std::env::var_os("LISTEN_PID").is_none() + && std::env::var_os("LISTEN_FDS").is_none() + && std::env::var_os("LISTEN_FDNAMES").is_none() + { // Systemd is not used - make the map empty return Ok(SystemdSockets(Mutex::new(Default::default()))); } } - if protected { Self::check_single_thread()? } - // MUST BE true FOR SAFETY!!! - let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| { - (name, Socket::try_from(fd)) - }).collect(); + if protected { + Self::check_single_thread()? + } + // MUST BE true FOR SAFETY!!! + let map = libsystemd::activation::receive_descriptors_with_names( + /*unset env = */ protected, + )? + .into_iter() + .map(|(fd, name)| (name, Socket::try_from(fd))) + .collect(); Ok(SystemdSockets(Mutex::new(map))) } @@ -264,7 +276,8 @@ } } - static SYSTEMD_SOCKETS: once_cell::sync::OnceCell<Result<SystemdSockets, Mutex<InitError>>> = once_cell::sync::OnceCell::new(); + static SYSTEMD_SOCKETS: once_cell::sync::OnceCell<Result<SystemdSockets, Mutex<InitError>>> = + once_cell::sync::OnceCell::new(); } /// Socket address that can be an ordinary address or a systemd socket @@ -277,7 +290,11 @@ /// Optional dependencies on `parse_arg` and `serde` make it trivial to use with /// [`configure_me`](https://crates.io/crates/configure_me). #[derive(Debug)] -#[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))] +#[cfg_attr( + feature = "serde", + derive(serde_crate::Deserialize), + serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr") +)] pub struct SocketAddr(SocketAddrInner); impl SocketAddr { @@ -297,15 +314,29 @@ }; let name_len = real_systemd_name.len(); - match real_systemd_name.chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') { + match real_systemd_name + .chars() + .enumerate() + .find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') + { None if name_len <= 255 && prefixed => Ok(SocketAddr(SocketAddrInner::Systemd(name))), - None if name_len <= 255 && !prefixed => Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))), - None => Err(ParseErrorInner::LongSocketName { string: name, len: name_len }.into()), - Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into()), + None if name_len <= 255 && !prefixed => { + Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))) + } + None => Err(ParseErrorInner::LongSocketName { + string: name, + len: name_len, + } + .into()), + Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { + string: name, + c, + pos, + } + .into()), } } - #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> { Err(ParseError(ParseErrorInner::SystemdUnsupported(name))) @@ -319,14 +350,19 @@ match self.0 { SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { Ok(socket) => Ok(socket), - Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), + Err(error) => Err(BindErrorInner::BindFailed { addr, error }.into()), }, - SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) { + SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) + { Ok(socket) => Ok(socket), - Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), + Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error }.into()), }, - SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket), - SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket), + SocketAddrInner::Systemd(socket_name) => { + Self::get_systemd(socket_name, true).map(|(socket, _)| socket) + } + SocketAddrInner::SystemdNoPrefix(socket_name) => { + Self::get_systemd(socket_name, false).map(|(socket, _)| socket) + } } } @@ -341,20 +377,30 @@ match self.0 { SocketAddrInner::Ordinary(addr) => match tokio::net::TcpListener::bind(addr).await { Ok(socket) => Ok(socket), - Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), + Err(error) => Err(TokioBindError::Bind( + BindErrorInner::BindFailed { addr, error }.into(), + )), }, - SocketAddrInner::WithHostname(addr) => match tokio::net::TcpListener::bind(addr.as_str()).await { - Ok(socket) => Ok(socket), - Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), - }, + SocketAddrInner::WithHostname(addr) => { + match tokio::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind( + BindErrorInner::BindOrResolvFailed { addr, error }.into(), + )), + } + } SocketAddrInner::Systemd(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name, true)?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) - }, + socket + .try_into() + .map_err(|error| TokioConversionError { addr, error }.into()) + } SocketAddrInner::SystemdNoPrefix(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name, false)?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) - }, + socket + .try_into() + .map_err(|error| TokioConversionError { addr, error }.into()) + } } } @@ -367,22 +413,34 @@ #[cfg(feature = "tokio_0_2")] pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> { match self.0 { - SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await { - Ok(socket) => Ok(socket), - Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), - }, - SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { - Ok(socket) => Ok(socket), - Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), - }, + SocketAddrInner::Ordinary(addr) => { + match tokio_0_2::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind( + BindErrorInner::BindFailed { addr, error }.into(), + )), + } + } + SocketAddrInner::WithHostname(addr) => { + match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind( + BindErrorInner::BindOrResolvFailed { addr, error }.into(), + )), + } + } SocketAddrInner::Systemd(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name, true)?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) - }, + socket + .try_into() + .map_err(|error| TokioConversionError { addr, error }.into()) + } SocketAddrInner::SystemdNoPrefix(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name, false)?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) - }, + socket + .try_into() + .map_err(|error| TokioConversionError { addr, error }.into()) + } } } @@ -395,22 +453,34 @@ #[cfg(feature = "tokio_0_3")] pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> { match self.0 { - SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await { - Ok(socket) => Ok(socket), - Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), - }, - SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { - Ok(socket) => Ok(socket), - Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), - }, + SocketAddrInner::Ordinary(addr) => { + match tokio_0_3::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind( + BindErrorInner::BindFailed { addr, error }.into(), + )), + } + } + SocketAddrInner::WithHostname(addr) => { + match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind( + BindErrorInner::BindOrResolvFailed { addr, error }.into(), + )), + } + } SocketAddrInner::Systemd(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name, true)?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) - }, + socket + .try_into() + .map_err(|error| TokioConversionError { addr, error }.into()) + } SocketAddrInner::SystemdNoPrefix(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name, false)?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) - }, + socket + .try_into() + .map_err(|error| TokioConversionError { addr, error }.into()) + } } } @@ -423,40 +493,52 @@ #[cfg(feature = "async-std")] pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> { match self.0 { - SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await { - Ok(socket) => Ok(socket), - Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), - }, - SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await { - Ok(socket) => Ok(socket), - Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), - }, + SocketAddrInner::Ordinary(addr) => { + match async_std::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindFailed { addr, error }.into()), + } + } + SocketAddrInner::WithHostname(addr) => { + match async_std::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error }.into()), + } + } SocketAddrInner::Systemd(socket_name) => { let (socket, _) = Self::get_systemd(socket_name, true)?; Ok(socket.into()) - }, + } SocketAddrInner::SystemdNoPrefix(socket_name) => { let (socket, _) = Self::get_systemd(socket_name, false)?; Ok(socket.into()) - }, + } } } // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan // rules. - fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> { + fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> + where + T: 'a + std::ops::Deref<Target = str> + Into<String>, + { if string.starts_with(SYSTEMD_PREFIX) { Self::inner_from_systemd_name(string.into(), true) } else { match string.parse() { Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), - Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))), + Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname( + ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?, + ))), } } } #[cfg(all(target_os = "linux", feature = "enable_systemd"))] - fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { + fn get_systemd( + socket_name: String, + prefixed: bool, + ) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { use systemd_sockets::Socket; let real_systemd_name = if prefixed { @@ -465,19 +547,25 @@ &socket_name }; - let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; + let socket = + systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; // match instead of combinators to avoid cloning socket_name match socket { - Some(Ok(Socket::TcpListener(socket))) => Ok((socket, SocketAddrInner::Systemd(socket_name))), + Some(Ok(Socket::TcpListener(socket))) => { + Ok((socket, SocketAddrInner::Systemd(socket_name))) + } Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), - None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) + None => Err(BindErrorInner::MissingDescriptor(socket_name).into()), } } // This approach makes the rest of the code much simpler as it doesn't require sprinkling it // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] - fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { + fn get_systemd( + socket_name: Never, + _prefixed: bool, + ) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { match socket_name {} } } @@ -515,7 +603,7 @@ /// /// If for any reason you're unable to call `init` in a single thread at around the top of `main` /// (and this should be almost never) you may call this method if you've ensured that no other part -/// of your codebase is operating on systemd-provided file descriptors stored in the environment +/// of your codebase is operating on systemd-provided file descriptors stored in the environment /// variables. /// /// Note however that doing so uncovers another problem: if another thread forks and execs the @@ -625,7 +713,10 @@ type Error = ParseOsStrError; fn try_from(s: &'a OsStr) -> Result<Self, Self::Error> { - s.to_str().ok_or(ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into) + s.to_str() + .ok_or(ParseOsStrError::InvalidUtf8)? + .try_into() + .map_err(Into::into) } } @@ -633,7 +724,10 @@ type Error = ParseOsStrError; fn try_from(s: OsString) -> Result<Self, Self::Error> { - s.into_string().map_err(|_| ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into) + s.into_string() + .map_err(|_| ParseOsStrError::InvalidUtf8)? + .try_into() + .map_err(Into::into) } } @@ -670,13 +764,19 @@ #[test] fn parse_ordinary() { - assert_eq!("127.0.0.1:42".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into())); + assert_eq!( + "127.0.0.1:42".parse::<SocketAddr>().unwrap().0, + SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into()) + ); } #[test] #[cfg(all(target_os = "linux", feature = "enable_systemd"))] fn parse_systemd() { - assert_eq!("systemd://foo".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Systemd("systemd://foo".to_owned())); + assert_eq!( + "systemd://foo".parse::<SocketAddr>().unwrap().0, + SocketAddrInner::Systemd("systemd://foo".to_owned()) + ); } #[test] @@ -711,7 +811,10 @@ } #[test] - #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)] + #[cfg_attr( + not(all(target_os = "linux", feature = "enable_systemd")), + should_panic + )] fn no_prefix_parse_systemd() { SocketAddr::from_systemd_name("foo").unwrap(); }