changeset 24:1941e9d9819c

Fix unsound manipulation of env vars. Modifying env vars in a multi-threaded process is unsound, but this crate was neither checking the number of threads nor marking its functions as `unsafe`. This change fixes it by both adding a check and adding an `unsafe` function that can bypass that check if needed.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Fri, 28 Feb 2025 13:52:31 +0100
parents 729392c49b46
children 8e20daee41ed
files Cargo.toml src/error.rs src/lib.rs tests/comm.rs
diffstat 4 files changed, 188 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/Cargo.toml	Sat Jul 13 15:14:32 2024 +0200
+++ b/Cargo.toml	Fri Feb 28 13:52:31 2025 +0100
@@ -30,7 +30,7 @@
 serde_crate = { package = "serde", version = "1.0.116", optional = true, features = ["derive"] }
 serde_str_helpers = { version = "0.1.2", optional = true }
 parse_arg = { version = "0.1.4", optional = true }
-lazy_static = "1.4.0"
+once_cell = "1.13.0"
 tokio = { package = "tokio", version = "1.0.0", optional = true, features = ["net"] }
 tokio_0_2 = { package = "tokio", version = "0.2", optional = true, features = ["tcp", "dns"] }
 tokio_0_3 = { package = "tokio", version = "0.3", optional = true, features = ["net"] }
--- a/src/error.rs	Sat Jul 13 15:14:32 2024 +0200
+++ b/src/error.rs	Fri Feb 28 13:52:31 2025 +0100
@@ -4,8 +4,33 @@
 
 
 use thiserror::Error;
+use std::fmt;
 use std::io;
 
+/// Error returned when the library initialization fails.
+#[derive(Debug)]
+pub struct InitError(pub(crate) InitErrorInner);
+
+#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
+type InitErrorInner = super::systemd_sockets::InitError;
+
+#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
+type InitErrorInner = std::convert::Infallible;
+
+impl fmt::Display for InitError {
+    #[inline]
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt::Display::fmt(&self.0, f)
+    }
+}
+
+impl std::error::Error for InitError {
+    #[inline]
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.0.source()
+    }
+}
+
 /// Error that can occur during parsing of `SocketAddr` from a string
 ///
 /// This encapsulates possible errors that can occur when parsing the input.
--- a/src/lib.rs	Sat Jul 13 15:14:32 2024 +0200
+++ b/src/lib.rs	Fri Feb 28 13:52:31 2025 +0100
@@ -1,12 +1,15 @@
 //! A convenience crate for optionally supporting systemd socket activation.
 //! 
 //! ## About
+//!
+//! **Important:** because of various reasons it is recommended to call the [`init`] function at
+//! the start of your program!
 //! 
 //! The goal of this crate is to make socket activation with systemd in your project trivial.
 //! It provides a replacement for `std::net::SocketAddr` that allows parsing the bind address from string just like the one from `std`
 //! but on top of that also allows `systemd://socket_name` format that tells it to use systemd activation with given socket name.
 //! Then it provides a method to bind the address which will return the socket from systemd if available.
-//! 
+//!
 //! The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flag.
 //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format.
 //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding.
@@ -24,6 +27,7 @@
 //! use std::convert::TryFrom;
 //! use std::io::Write;
 //! 
+//! systemd_socket::init().expect("Failed to initialize systemd sockets");
 //! let mut args = std::env::args_os();
 //! let program_name = args.next().expect("unknown program name");
 //! let socket_addr = args.next().expect("missing socket address");
@@ -52,6 +56,20 @@
 //! * `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr`
 //! * `async_std` - adds `bind_async_std` method to `SocketAddr`
 //!
+//! ## Soundness
+//!
+//! The systemd file descriptors are transferred using environment variables and since they are
+//! file descriptors, they should have move semantics. However environment variables in Rust do not
+//! have move semantics and even modifying them is very dangerous.
+//!
+//! Because of this, the crate only allows initialization when there's only one thread running.
+//! However that still doesn't prevent all possible problems: if some other code closes file
+//! descriptors stored in those environment variables you can get an invalid socket.
+//!
+//! This situation is obviously ridiculous because there shouldn't be a reason to use another
+//! library to do the same thing. It could also be argued that whichever code doesn't clear the
+//! environment variable is broken (even if understandably so) and it's not a fault of this library.
+//!
 //! ## MSRV
 //!
 //! This crate must always compile with the latest Rust available in the latest Debian stable.
@@ -79,46 +97,131 @@
     use std::sync::Mutex;
     use libsystemd::activation::FileDescriptor;
     use libsystemd::errors::SdError as LibSystemdError;
-    type LibSystemdResult<T> = Result<T, LibSystemdError>;
 
     #[derive(Debug)]
-    pub(crate) struct Error(&'static Mutex<LibSystemdError>);
+    pub(crate) struct Error(&'static Mutex<InitError>);
 
     impl fmt::Display for Error {
         fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-            fmt::Display::fmt(&*self.0.lock().expect("mutex poisoned"), f)
+            use std::error::Error as _;
+
+            let guard = self.0.lock().expect("mutex poisoned");
+            fmt::Display::fmt(&*guard, f)?;
+            let mut source_opt = guard.source();
+            while let Some(source) = source_opt {
+                write!(f, ": {}", source)?;
+                source_opt = source.source();
+            }
+            Ok(())
         }
     }
 
     // No source we can't keep the mutex locked
     impl std::error::Error for Error {}
 
+    pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> {
+        SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop)
+    }
+
     pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> {
-        match &*SYSTEMD_SOCKETS {
+        let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new));
+        match sockets {
             Ok(sockets) => Ok(sockets.take(name)),
             Err(error) => Err(Error(error))
         }
     }
 
+    #[derive(Debug)]
+    pub(crate) enum InitError {
+        OpenStatus(std::io::Error),
+        ReadStatus(std::io::Error),
+        ThreadCountNotFound,
+        MultipleThreads,
+        LibSystemd(LibSystemdError),
+    }
+
+    impl From<LibSystemdError> for InitError {
+        fn from(value: LibSystemdError) -> Self {
+            Self::LibSystemd(value)
+        }
+    }
+
+    impl fmt::Display for InitError {
+        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+            match self {
+                Self::OpenStatus(_) => write!(f, "failed to open /proc/self/status"),
+                Self::ReadStatus(_) => write!(f, "failed to read /proc/self/status"),
+                Self::ThreadCountNotFound => write!(f, "/proc/self/status doesn't contain Threads entry"),
+                Self::MultipleThreads => write!(f, "there is more than one thread running"),
+                // We have nothing to say about the error, let's flatten it
+                Self::LibSystemd(error) => fmt::Display::fmt(error, f),
+            }
+        }
+    }
+
+    impl std::error::Error for InitError {
+        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+            match self {
+                Self::OpenStatus(error) => Some(error),
+                Self::ReadStatus(error) => Some(error),
+                Self::ThreadCountNotFound => None,
+                Self::MultipleThreads => None,
+                // We have nothing to say about the error, let's flatten it
+                Self::LibSystemd(error) => error.source(),
+            }
+        }
+    }
+
     struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, FileDescriptor>>);
 
     impl SystemdSockets {
-        fn new() -> LibSystemdResult<Self> {
+        fn new_protected(explicit: bool) -> Result<Self, InitError> {
+            unsafe { Self::new(true, explicit) }
+        }
+
+        unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> {
+            if explicit {
+                if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() {
+                    // Systemd is not used - make the map empty
+                    return Ok(SystemdSockets(Mutex::new(Default::default())));
+                }
+            }
+
+            if protected { Self::check_single_thread()? }
                                                                             // MUST BE true FOR SAFETY!!!
-            let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ true)?.into_iter().map(|(fd, name)| (name, fd)).collect();
+            let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| (name, fd)).collect();
             Ok(SystemdSockets(Mutex::new(map)))
         }
 
+        fn check_single_thread() -> Result<(), InitError> {
+            use std::io::BufRead;
+
+            let status = std::fs::File::open("/proc/self/status").map_err(InitError::OpenStatus)?;
+            let mut status = std::io::BufReader::new(status);
+            let mut line = String::new();
+            loop {
+                if status.read_line(&mut line).map_err(InitError::ReadStatus)? == 0 {
+                    return Err(InitError::ThreadCountNotFound);
+                }
+                if let Some(threads) = line.strip_prefix("Threads:") {
+                    if threads.trim() == "1" {
+                        break;
+                    } else {
+                        return Err(InitError::MultipleThreads);
+                    }
+                }
+                line.clear();
+            }
+            Ok(())
+        }
+
         fn take(&self, name: &str) -> Option<FileDescriptor> {
             // MUST remove THE SOCKET FOR SAFETY!!!
             self.0.lock().expect("poisoned mutex").remove(name)
         }
     }
 
-    lazy_static::lazy_static! {
-        // We don't panic in order to let the application handle the error later
-        static ref SYSTEMD_SOCKETS: Result<SystemdSockets, Mutex<LibSystemdError>> = SystemdSockets::new().map_err(Mutex::new);
-    }
+    static SYSTEMD_SOCKETS: once_cell::sync::OnceCell<Result<SystemdSockets, Mutex<InitError>>> = once_cell::sync::OnceCell::new();
 }
 
 /// Socket address that can be an ordinary address or a systemd socket
@@ -342,6 +445,49 @@
     }
 }
 
+/// Initializes the library while there's only a single thread.
+///
+/// Unfortunately, this library has to be initialized and, for soundness, this initialization must
+/// happen when no other threads are running. This is attempted automatically when trying to bind a
+/// systemd socket but at that time there may be other threads running and error reporting also
+/// faces some restrictions. This function provides better control over the initialization point
+/// and returns a more idiomatic error type.
+///
+/// You should generally call this at around the top of `main`, where no threads were created yet.
+#[inline]
+pub fn init() -> Result<(), error::InitError> {
+    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
+    {
+        // Calling with true is always sound
+        unsafe { systemd_sockets::init(true) }.map_err(error::InitError)
+    }
+    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
+    {
+        Ok(())
+    }
+}
+
+/// Initializes the library without protection against double close.
+///
+/// Unfortunately, this library has to be initialized and, because double closing file descriptors
+/// is unsound, the library has some protections against double close. However these protections
+/// come with the limitation that the library must be initialized with a single thread.
+///
+/// If for any reason you're unable to call `init` in a single thread at around the top of `main`
+/// (and this should be almost never) you may call this method if you've ensured that no other part
+/// of your codebase is operating on systemd-provided file descriptors stored in the environment 
+/// variables.
+pub unsafe fn init_unprotected() -> Result<(), error::InitError> {
+    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
+    {
+        systemd_sockets::init(false).map_err(error::InitError)
+    }
+    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
+    {
+        Ok(())
+    }
+}
+
 /// Displays the address in format that can be parsed again.
 ///
 /// **Important: While I don't expect this impl to change, don't rely on it!**
--- a/tests/comm.rs	Sat Jul 13 15:14:32 2024 +0200
+++ b/tests/comm.rs	Fri Feb 28 13:52:31 2025 +0100
@@ -36,6 +36,11 @@
 fn main_slave(addr: &str) {
     use systemd_socket::SocketAddr;
 
+    // SAFETY: this is the only thread that's going to mess with systemd sockets.
+    unsafe {
+        systemd_socket::init_unprotected().unwrap();
+    }
+
     let socket = addr
         .parse::<SocketAddr>()
         .expect("failed to parse socket")