changeset 25:8e20daee41ed

Set CLOEXEC flag on the descriptors. The received systemd descriptors don't have `O_CLOEXEC` set because they are inherited across `exec`. Thus, if the process executes a child, the child gets polluted with the descriptors. To prevent this, we set `O_CLOEXEC` during initialization. However, this also required restructuring the code because `libsystemd` doesn't provide temporary access to the descriptors - only permanent access. Thus we have to "validate" the descriptors eagerly. We still store the invalid ones as errors to make sure the errors get reported accurately.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Fri, 28 Feb 2025 21:11:19 +0100
parents 1941e9d9819c
children 0feab4f4c2ce
files Cargo.toml src/lib.rs
diffstat 2 files changed, 61 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/Cargo.toml	Fri Feb 28 13:52:31 2025 +0100
+++ b/Cargo.toml	Fri Feb 28 21:11:19 2025 +0100
@@ -35,3 +35,4 @@
 tokio_0_2 = { package = "tokio", version = "0.2", optional = true, features = ["tcp", "dns"] }
 tokio_0_3 = { package = "tokio", version = "0.3", optional = true, features = ["net"] }
 async-std = { version = "1.7.0", optional = true }
+libc = "0.2.155"
--- a/src/lib.rs	Fri Feb 28 13:52:31 2025 +0100
+++ b/src/lib.rs	Fri Feb 28 21:11:19 2025 +0100
@@ -116,6 +116,8 @@
         }
     }
 
+    type StoredSocket = Result<Socket, ()>;
+
     // No source we can't keep the mutex locked
     impl std::error::Error for Error {}
 
@@ -123,7 +125,7 @@
         SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop)
     }
 
-    pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> {
+    pub(crate) fn take(name: &str) -> Result<Option<StoredSocket>, Error> {
         let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new));
         match sockets {
             Ok(sockets) => Ok(sockets.take(name)),
@@ -172,7 +174,44 @@
         }
     }
 
-    struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, FileDescriptor>>);
+    pub(crate) enum Socket {
+        TcpListener(std::net::TcpListener),
+    }
+
+    impl std::convert::TryFrom<FileDescriptor> for Socket {
+        type Error = ();
+
+        fn try_from(value: FileDescriptor) -> Result<Self, Self::Error> {
+            use libsystemd::activation::IsType;
+            use std::os::unix::io::{FromRawFd, IntoRawFd, AsRawFd};
+
+            fn set_cloexec(fd: std::os::unix::io::RawFd) {
+                // SAFETY: The function is a harmless syscall
+                let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
+                if flags != -1 && flags & libc::FD_CLOEXEC == 0 {
+                    // We ignore errors, since the FD is still usable
+                    // SAFETY: socket is definitely a valid file descriptor and setting CLOEXEC is
+                    // a sound operation.
+                    unsafe {
+                        libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC);
+                    }
+                }
+            }
+
+            if value.is_inet() {
+                // SAFETY: FileDescriptor is obtained from systemd, so it should be valid.
+                let socket = unsafe { std::net::TcpListener::from_raw_fd(value.into_raw_fd()) };
+                set_cloexec(socket.as_raw_fd());
+                Ok(Socket::TcpListener(socket))
+            } else {
+                // We still need to make the filedescriptor harmless.
+                set_cloexec(value.into_raw_fd());
+                Err(())
+            }
+        }
+    }
+
+    struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, StoredSocket>>);
 
     impl SystemdSockets {
         fn new_protected(explicit: bool) -> Result<Self, InitError> {
@@ -180,6 +219,8 @@
         }
 
         unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> {
+            use std::convert::TryFrom;
+
             if explicit {
                 if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() {
                     // Systemd is not used - make the map empty
@@ -189,7 +230,9 @@
 
             if protected { Self::check_single_thread()? }
                                                                             // MUST BE true FOR SAFETY!!!
-            let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| (name, fd)).collect();
+            let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| {
+                (name, Socket::try_from(fd))
+            }).collect();
             Ok(SystemdSockets(Mutex::new(map)))
         }
 
@@ -215,7 +258,7 @@
             Ok(())
         }
 
-        fn take(&self, name: &str) -> Option<FileDescriptor> {
+        fn take(&self, name: &str) -> Option<StoredSocket> {
             // MUST remove THE SOCKET FOR SAFETY!!!
             self.0.lock().expect("poisoned mutex").remove(name)
         }
@@ -414,8 +457,7 @@
 
     #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
     fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
-        use libsystemd::activation::IsType;
-        use std::os::unix::io::{FromRawFd, IntoRawFd};
+        use systemd_sockets::Socket;
 
         let real_systemd_name = if prefixed {
             &socket_name[SYSTEMD_PREFIX.len()..]
@@ -424,16 +466,11 @@
         };
 
         let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?;
-        // Safety: The environment variable is unset, so that no other calls can get the
-        // descriptors. The descriptors are taken from the map, not cloned, so they can't
-        // be duplicated.
-        unsafe {
-            // match instead of combinators to avoid cloning socket_name
-            match socket {
-                Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
-                Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
-                None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
-            }
+        // match instead of combinators to avoid cloning socket_name
+        match socket {
+            Some(Ok(Socket::TcpListener(socket))) => Ok((socket, SocketAddrInner::Systemd(socket_name))),
+            Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
+            None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
         }
     }
 
@@ -454,6 +491,9 @@
 /// and returns a more idiomatic error type.
 ///
 /// You should generally call this at around the top of `main`, where no threads were created yet.
+/// While technically you may spawn a thread and call this function after that thread has terminated,
+/// this has the additional problem that the descriptors are still around, so if that thread (or the
+/// current one!) forks and execs, the descriptors will leak into the child.
 #[inline]
 pub fn init() -> Result<(), error::InitError> {
     #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
@@ -477,6 +517,10 @@
 /// (and this should be almost never) you may call this method if you've ensured that no other part
 /// of your codebase is operating on systemd-provided file descriptors stored in the environment 
 /// variables.
+///
+/// Note, however, that doing so uncovers another problem: if another thread forks and execs, the
+/// systemd file descriptors will get passed into that program! In such a case you somehow need to
+/// clean up the file descriptors yourself.
 pub unsafe fn init_unprotected() -> Result<(), error::InitError> {
     #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
     {