Mercurial > crates > systemd-socket
comparison src/lib.rs @ 25:8e20daee41ed
Set CLOEXEC flag on the descriptors
The systemd-provided descriptors don't have `O_CLOEXEC` set because they
are inherited across `exec`. Thus, if the process executes a child, the
child gets polluted with the descriptors.
To prevent this, we set `O_CLOEXEC` during initialization. However, this
also required restructuring the code because `libsystemd` doesn't
provide temporary access to the descriptors - only a permanent one. Thus
we have to "validate" the descriptors eagerly. We still store the
invalid ones as errors to make sure the errors get reported accurately.
| author | Martin Habovstiak <martin.habovstiak@gmail.com> |
|---|---|
| date | Fri, 28 Feb 2025 21:11:19 +0100 |
| parents | 1941e9d9819c |
| children | 0feab4f4c2ce |
comparison
equal
deleted
inserted
replaced
| 24:1941e9d9819c | 25:8e20daee41ed |
|---|---|
| 114 } | 114 } |
| 115 Ok(()) | 115 Ok(()) |
| 116 } | 116 } |
| 117 } | 117 } |
| 118 | 118 |
| 119 type StoredSocket = Result<Socket, ()>; | |
| 120 | |
| 119 // No source we can't keep the mutex locked | 121 // No source we can't keep the mutex locked |
| 120 impl std::error::Error for Error {} | 122 impl std::error::Error for Error {} |
| 121 | 123 |
| 122 pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> { | 124 pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> { |
| 123 SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop) | 125 SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop) |
| 124 } | 126 } |
| 125 | 127 |
| 126 pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> { | 128 pub(crate) fn take(name: &str) -> Result<Option<StoredSocket>, Error> { |
| 127 let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new)); | 129 let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new)); |
| 128 match sockets { | 130 match sockets { |
| 129 Ok(sockets) => Ok(sockets.take(name)), | 131 Ok(sockets) => Ok(sockets.take(name)), |
| 130 Err(error) => Err(Error(error)) | 132 Err(error) => Err(Error(error)) |
| 131 } | 133 } |
| 170 Self::LibSystemd(error) => error.source(), | 172 Self::LibSystemd(error) => error.source(), |
| 171 } | 173 } |
| 172 } | 174 } |
| 173 } | 175 } |
| 174 | 176 |
| 175 struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, FileDescriptor>>); | 177 pub(crate) enum Socket { |
| 178 TcpListener(std::net::TcpListener), | |
| 179 } | |
| 180 | |
| 181 impl std::convert::TryFrom<FileDescriptor> for Socket { | |
| 182 type Error = (); | |
| 183 | |
| 184 fn try_from(value: FileDescriptor) -> Result<Self, Self::Error> { | |
| 185 use libsystemd::activation::IsType; | |
| 186 use std::os::unix::io::{FromRawFd, IntoRawFd, AsRawFd}; | |
| 187 | |
| 188 fn set_cloexec(fd: std::os::unix::io::RawFd) { | |
| 189 // SAFETY: The function is a harmless syscall | |
| 190 let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) }; | |
| 191 if flags != -1 && flags & libc::FD_CLOEXEC == 0 { | |
| 192 // We ignore errors, since the FD is still usable | |
| 193 // SAFETY: socket is definitely a valid file descriptor and setting CLOEXEC is | |
| 194 // a sound operation. | |
| 195 unsafe { | |
| 196 libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC); | |
| 197 } | |
| 198 } | |
| 199 } | |
| 200 | |
| 201 if value.is_inet() { | |
| 202 // SAFETY: FileDescriptor is obtained from systemd, so it should be valid. | |
| 203 let socket = unsafe { std::net::TcpListener::from_raw_fd(value.into_raw_fd()) }; | |
| 204 set_cloexec(socket.as_raw_fd()); | |
| 205 Ok(Socket::TcpListener(socket)) | |
| 206 } else { | |
| 207 // We still need to make the filedescriptor harmless. | |
| 208 set_cloexec(value.into_raw_fd()); | |
| 209 Err(()) | |
| 210 } | |
| 211 } | |
| 212 } | |
| 213 | |
| 214 struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, StoredSocket>>); | |
| 176 | 215 |
| 177 impl SystemdSockets { | 216 impl SystemdSockets { |
| 178 fn new_protected(explicit: bool) -> Result<Self, InitError> { | 217 fn new_protected(explicit: bool) -> Result<Self, InitError> { |
| 179 unsafe { Self::new(true, explicit) } | 218 unsafe { Self::new(true, explicit) } |
| 180 } | 219 } |
| 181 | 220 |
| 182 unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> { | 221 unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> { |
| 222 use std::convert::TryFrom; | |
| 223 | |
| 183 if explicit { | 224 if explicit { |
| 184 if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() { | 225 if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() { |
| 185 // Systemd is not used - make the map empty | 226 // Systemd is not used - make the map empty |
| 186 return Ok(SystemdSockets(Mutex::new(Default::default()))); | 227 return Ok(SystemdSockets(Mutex::new(Default::default()))); |
| 187 } | 228 } |
| 188 } | 229 } |
| 189 | 230 |
| 190 if protected { Self::check_single_thread()? } | 231 if protected { Self::check_single_thread()? } |
| 191 // MUST BE true FOR SAFETY!!! | 232 // MUST BE true FOR SAFETY!!! |
| 192 let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| (name, fd)).collect(); | 233 let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| { |
| 234 (name, Socket::try_from(fd)) | |
| 235 }).collect(); | |
| 193 Ok(SystemdSockets(Mutex::new(map))) | 236 Ok(SystemdSockets(Mutex::new(map))) |
| 194 } | 237 } |
| 195 | 238 |
| 196 fn check_single_thread() -> Result<(), InitError> { | 239 fn check_single_thread() -> Result<(), InitError> { |
| 197 use std::io::BufRead; | 240 use std::io::BufRead; |
| 213 line.clear(); | 256 line.clear(); |
| 214 } | 257 } |
| 215 Ok(()) | 258 Ok(()) |
| 216 } | 259 } |
| 217 | 260 |
| 218 fn take(&self, name: &str) -> Option<FileDescriptor> { | 261 fn take(&self, name: &str) -> Option<StoredSocket> { |
| 219 // MUST remove THE SOCKET FOR SAFETY!!! | 262 // MUST remove THE SOCKET FOR SAFETY!!! |
| 220 self.0.lock().expect("poisoned mutex").remove(name) | 263 self.0.lock().expect("poisoned mutex").remove(name) |
| 221 } | 264 } |
| 222 } | 265 } |
| 223 | 266 |
| 412 } | 455 } |
| 413 } | 456 } |
| 414 | 457 |
| 415 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 458 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
| 416 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { | 459 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { |
| 417 use libsystemd::activation::IsType; | 460 use systemd_sockets::Socket; |
| 418 use std::os::unix::io::{FromRawFd, IntoRawFd}; | |
| 419 | 461 |
| 420 let real_systemd_name = if prefixed { | 462 let real_systemd_name = if prefixed { |
| 421 &socket_name[SYSTEMD_PREFIX.len()..] | 463 &socket_name[SYSTEMD_PREFIX.len()..] |
| 422 } else { | 464 } else { |
| 423 &socket_name | 465 &socket_name |
| 424 }; | 466 }; |
| 425 | 467 |
| 426 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; | 468 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; |
| 427 // Safety: The environment variable is unset, so that no other calls can get the | 469 // match instead of combinators to avoid cloning socket_name |
| 428 // descriptors. The descriptors are taken from the map, not cloned, so they can't | 470 match socket { |
| 429 // be duplicated. | 471 Some(Ok(Socket::TcpListener(socket))) => Ok((socket, SocketAddrInner::Systemd(socket_name))), |
| 430 unsafe { | 472 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), |
| 431 // match instead of combinators to avoid cloning socket_name | 473 None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) |
| 432 match socket { | |
| 433 Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), | |
| 434 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), | |
| 435 None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) | |
| 436 } | |
| 437 } | 474 } |
| 438 } | 475 } |
| 439 | 476 |
| 440 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it | 477 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it |
| 441 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. | 478 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. |
| 452 /// systemd socket but at that time there may be other threads running and error reporting also | 489 /// systemd socket but at that time there may be other threads running and error reporting also |
| 453 /// faces some restrictions. This function provides better control over the initialization point | 490 /// faces some restrictions. This function provides better control over the initialization point |
| 454 /// and returns a more idiomatic error type. | 491 /// and returns a more idiomatic error type. |
| 455 /// | 492 /// |
| 456 /// You should generally call this at around the top of `main`, where no threads were created yet. | 493 /// You should generally call this at around the top of `main`, where no threads were created yet. |
| 494 /// While technically, you may spawn a thread and call this function after that thread terminated, | |
| 495 /// this has the additional problem that the descriptors are still around, so if that thread (or the | |
| 496 /// current one!) forks and execs the descriptors will leak into the child. | |
| 457 #[inline] | 497 #[inline] |
| 458 pub fn init() -> Result<(), error::InitError> { | 498 pub fn init() -> Result<(), error::InitError> { |
| 459 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 499 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
| 460 { | 500 { |
| 461 // Calling with true is always sound | 501 // Calling with true is always sound |
| 475 /// | 515 /// |
| 476 /// If for any reason you're unable to call `init` in a single thread at around the top of `main` | 516 /// If for any reason you're unable to call `init` in a single thread at around the top of `main` |
| 477 /// (and this should be almost never) you may call this method if you've ensured that no other part | 517 /// (and this should be almost never) you may call this method if you've ensured that no other part |
| 478 /// of your codebase is operating on systemd-provided file descriptors stored in the environment | 518 /// of your codebase is operating on systemd-provided file descriptors stored in the environment |
| 479 /// variables. | 519 /// variables. |
| 520 /// | |
| 521 /// Note however that doing so uncovers another problem: if another thread forks and execs the | |
| 522 /// systemd file descriptors will get passed into that program! In such case you somehow need to | |
| 523 /// clean up the file descriptors yourself. | |
| 480 pub unsafe fn init_unprotected() -> Result<(), error::InitError> { | 524 pub unsafe fn init_unprotected() -> Result<(), error::InitError> { |
| 481 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 525 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
| 482 { | 526 { |
| 483 systemd_sockets::init(false).map_err(error::InitError) | 527 systemd_sockets::init(false).map_err(error::InitError) |
| 484 } | 528 } |
