comparison src/lib.rs @ 25:8e20daee41ed

Set CLOEXEC flag on the descriptors. The received systemd descriptors don't have `O_CLOEXEC` set because they are inherited across `exec`. Thus, if the process executes a child, the child gets polluted with the descriptors. To prevent this, we set `O_CLOEXEC` during initialization. However, this also required restructuring the code because `libsystemd` doesn't provide temporary (borrowed) access to the descriptors - only permanent (owned) access. Thus, we have to "validate" the descriptors eagerly. We still store the invalid ones as errors to make sure the errors get reported accurately.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Fri, 28 Feb 2025 21:11:19 +0100
parents 1941e9d9819c
children 0feab4f4c2ce
comparison
equal deleted inserted replaced
24:1941e9d9819c 25:8e20daee41ed
114 } 114 }
115 Ok(()) 115 Ok(())
116 } 116 }
117 } 117 }
118 118
119 type StoredSocket = Result<Socket, ()>;
120
119 // No source we can't keep the mutex locked 121 // No source we can't keep the mutex locked
120 impl std::error::Error for Error {} 122 impl std::error::Error for Error {}
121 123
122 pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> { 124 pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> {
123 SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop) 125 SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop)
124 } 126 }
125 127
126 pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> { 128 pub(crate) fn take(name: &str) -> Result<Option<StoredSocket>, Error> {
127 let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new)); 129 let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new));
128 match sockets { 130 match sockets {
129 Ok(sockets) => Ok(sockets.take(name)), 131 Ok(sockets) => Ok(sockets.take(name)),
130 Err(error) => Err(Error(error)) 132 Err(error) => Err(Error(error))
131 } 133 }
170 Self::LibSystemd(error) => error.source(), 172 Self::LibSystemd(error) => error.source(),
171 } 173 }
172 } 174 }
173 } 175 }
174 176
175 struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, FileDescriptor>>); 177 pub(crate) enum Socket {
178 TcpListener(std::net::TcpListener),
179 }
180
181 impl std::convert::TryFrom<FileDescriptor> for Socket {
182 type Error = ();
183
184 fn try_from(value: FileDescriptor) -> Result<Self, Self::Error> {
185 use libsystemd::activation::IsType;
186 use std::os::unix::io::{FromRawFd, IntoRawFd, AsRawFd};
187
188 fn set_cloexec(fd: std::os::unix::io::RawFd) {
189 // SAFETY: The function is a harmless syscall
190 let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
191 if flags != -1 && flags & libc::FD_CLOEXEC == 0 {
192 // We ignore errors, since the FD is still usable
193 // SAFETY: socket is definitely a valid file descriptor and setting CLOEXEC is
194 // a sound operation.
195 unsafe {
196 libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC);
197 }
198 }
199 }
200
201 if value.is_inet() {
202 // SAFETY: FileDescriptor is obtained from systemd, so it should be valid.
203 let socket = unsafe { std::net::TcpListener::from_raw_fd(value.into_raw_fd()) };
204 set_cloexec(socket.as_raw_fd());
205 Ok(Socket::TcpListener(socket))
206 } else {
207 // We still need to make the filedescriptor harmless.
208 set_cloexec(value.into_raw_fd());
209 Err(())
210 }
211 }
212 }
213
214 struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, StoredSocket>>);
176 215
177 impl SystemdSockets { 216 impl SystemdSockets {
178 fn new_protected(explicit: bool) -> Result<Self, InitError> { 217 fn new_protected(explicit: bool) -> Result<Self, InitError> {
179 unsafe { Self::new(true, explicit) } 218 unsafe { Self::new(true, explicit) }
180 } 219 }
181 220
182 unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> { 221 unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> {
222 use std::convert::TryFrom;
223
183 if explicit { 224 if explicit {
184 if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() { 225 if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() {
185 // Systemd is not used - make the map empty 226 // Systemd is not used - make the map empty
186 return Ok(SystemdSockets(Mutex::new(Default::default()))); 227 return Ok(SystemdSockets(Mutex::new(Default::default())));
187 } 228 }
188 } 229 }
189 230
190 if protected { Self::check_single_thread()? } 231 if protected { Self::check_single_thread()? }
191 // MUST BE true FOR SAFETY!!! 232 // MUST BE true FOR SAFETY!!!
192 let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| (name, fd)).collect(); 233 let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| {
234 (name, Socket::try_from(fd))
235 }).collect();
193 Ok(SystemdSockets(Mutex::new(map))) 236 Ok(SystemdSockets(Mutex::new(map)))
194 } 237 }
195 238
196 fn check_single_thread() -> Result<(), InitError> { 239 fn check_single_thread() -> Result<(), InitError> {
197 use std::io::BufRead; 240 use std::io::BufRead;
213 line.clear(); 256 line.clear();
214 } 257 }
215 Ok(()) 258 Ok(())
216 } 259 }
217 260
218 fn take(&self, name: &str) -> Option<FileDescriptor> { 261 fn take(&self, name: &str) -> Option<StoredSocket> {
219 // MUST remove THE SOCKET FOR SAFETY!!! 262 // MUST remove THE SOCKET FOR SAFETY!!!
220 self.0.lock().expect("poisoned mutex").remove(name) 263 self.0.lock().expect("poisoned mutex").remove(name)
221 } 264 }
222 } 265 }
223 266
412 } 455 }
413 } 456 }
414 457
415 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] 458 #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
416 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { 459 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
417 use libsystemd::activation::IsType; 460 use systemd_sockets::Socket;
418 use std::os::unix::io::{FromRawFd, IntoRawFd};
419 461
420 let real_systemd_name = if prefixed { 462 let real_systemd_name = if prefixed {
421 &socket_name[SYSTEMD_PREFIX.len()..] 463 &socket_name[SYSTEMD_PREFIX.len()..]
422 } else { 464 } else {
423 &socket_name 465 &socket_name
424 }; 466 };
425 467
426 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; 468 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?;
427 // Safety: The environment variable is unset, so that no other calls can get the 469 // match instead of combinators to avoid cloning socket_name
428 // descriptors. The descriptors are taken from the map, not cloned, so they can't 470 match socket {
429 // be duplicated. 471 Some(Ok(Socket::TcpListener(socket))) => Ok((socket, SocketAddrInner::Systemd(socket_name))),
430 unsafe { 472 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
431 // match instead of combinators to avoid cloning socket_name 473 None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
432 match socket {
433 Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
434 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
435 None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
436 }
437 } 474 }
438 } 475 }
439 476
440 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it 477 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it
441 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. 478 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute.
452 /// systemd socket but at that time there may be other threads running and error reporting also 489 /// systemd socket but at that time there may be other threads running and error reporting also
453 /// faces some restrictions. This function provides better control over the initialization point 490 /// faces some restrictions. This function provides better control over the initialization point
454 /// and returns a more idiomatic error type. 491 /// and returns a more idiomatic error type.
455 /// 492 ///
456 /// You should generally call this at around the top of `main`, where no threads were created yet. 493 /// You should generally call this at around the top of `main`, where no threads were created yet.
494 /// While you may technically spawn a thread and call this function after that thread has terminated,
495 /// this has the additional problem that the descriptors are still around until then, so if that thread
496 /// (or the current one!) forks and execs, the descriptors will leak into the child.
457 #[inline] 497 #[inline]
458 pub fn init() -> Result<(), error::InitError> { 498 pub fn init() -> Result<(), error::InitError> {
459 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] 499 #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
460 { 500 {
461 // Calling with true is always sound 501 // Calling with true is always sound
475 /// 515 ///
476 /// If for any reason you're unable to call `init` in a single thread at around the top of `main` 516 /// If for any reason you're unable to call `init` in a single thread at around the top of `main`
477 /// (and this should be almost never) you may call this method if you've ensured that no other part 517 /// (and this should be almost never) you may call this method if you've ensured that no other part
478 /// of your codebase is operating on systemd-provided file descriptors stored in the environment 518 /// of your codebase is operating on systemd-provided file descriptors stored in the environment
479 /// variables. 519 /// variables.
520 ///
521 /// Note, however, that doing so uncovers another problem: if another thread forks and execs before
522 /// this function completes, the systemd file descriptors will get passed into that program! In such
523 /// a case you need to clean up the file descriptors yourself somehow.
480 pub unsafe fn init_unprotected() -> Result<(), error::InitError> { 524 pub unsafe fn init_unprotected() -> Result<(), error::InitError> {
481 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] 525 #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
482 { 526 {
483 systemd_sockets::init(false).map_err(error::InitError) 527 systemd_sockets::init(false).map_err(error::InitError)
484 } 528 }