Mercurial > crates > systemd-socket
comparison src/lib.rs @ 25:8e20daee41ed
Set CLOEXEC flag on the descriptors
The descriptors received from systemd don't have the close-on-exec flag (`FD_CLOEXEC`) set, because
they have to survive the `exec` through which they are passed to this process. Thus, if the process
executes a child, the child gets polluted with the descriptors.
To prevent this, we set `FD_CLOEXEC` during initialization. However, this
also required restructuring the code, because `libsystemd` doesn't
provide temporary access to the descriptors - only permanent access
(consuming the descriptor). Thus we have to "validate" the descriptors
eagerly. We still store the invalid ones as errors to make sure the
errors get reported accurately.
author | Martin Habovstiak <martin.habovstiak@gmail.com> |
---|---|
date | Fri, 28 Feb 2025 21:11:19 +0100 |
parents | 1941e9d9819c |
children | 0feab4f4c2ce |
comparison
equal
deleted
inserted
replaced
24:1941e9d9819c | 25:8e20daee41ed |
---|---|
114 } | 114 } |
115 Ok(()) | 115 Ok(()) |
116 } | 116 } |
117 } | 117 } |
118 | 118 |
119 type StoredSocket = Result<Socket, ()>; | |
120 | |
119 // No source because we can't keep the mutex locked | 121 // No source because we can't keep the mutex locked |
120 impl std::error::Error for Error {} | 122 impl std::error::Error for Error {} |
121 | 123 |
122 pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> { | 124 pub(crate) unsafe fn init(protected: bool) -> Result<(), InitError> { |
123 SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop) | 125 SYSTEMD_SOCKETS.get_or_try_init(|| SystemdSockets::new(protected, true).map(Ok)).map(drop) |
124 } | 126 } |
125 | 127 |
126 pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> { | 128 pub(crate) fn take(name: &str) -> Result<Option<StoredSocket>, Error> { |
127 let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new)); | 129 let sockets = SYSTEMD_SOCKETS.get_or_init(|| SystemdSockets::new_protected(false).map_err(Mutex::new)); |
128 match sockets { | 130 match sockets { |
129 Ok(sockets) => Ok(sockets.take(name)), | 131 Ok(sockets) => Ok(sockets.take(name)), |
130 Err(error) => Err(Error(error)) | 132 Err(error) => Err(Error(error)) |
131 } | 133 } |
170 Self::LibSystemd(error) => error.source(), | 172 Self::LibSystemd(error) => error.source(), |
171 } | 173 } |
172 } | 174 } |
173 } | 175 } |
174 | 176 |
175 struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, FileDescriptor>>); | 177 pub(crate) enum Socket { |
178 TcpListener(std::net::TcpListener), | |
179 } | |
180 | |
181 impl std::convert::TryFrom<FileDescriptor> for Socket { | |
182 type Error = (); | |
183 | |
184 fn try_from(value: FileDescriptor) -> Result<Self, Self::Error> { | |
185 use libsystemd::activation::IsType; | |
186 use std::os::unix::io::{FromRawFd, IntoRawFd, AsRawFd}; | |
187 | |
188 fn set_cloexec(fd: std::os::unix::io::RawFd) { | |
189 // SAFETY: The function is a harmless syscall | |
190 let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) }; | |
191 if flags != -1 && flags & libc::FD_CLOEXEC == 0 { | |
192 // We ignore errors, since the FD is still usable | |
193 // SAFETY: socket is definitely a valid file descriptor and setting CLOEXEC is | |
194 // a sound operation. | |
195 unsafe { | |
196 libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC); | |
197 } | |
198 } | |
199 } | |
200 | |
201 if value.is_inet() { | |
202 // SAFETY: FileDescriptor is obtained from systemd, so it should be valid. | |
203 let socket = unsafe { std::net::TcpListener::from_raw_fd(value.into_raw_fd()) }; | |
204 set_cloexec(socket.as_raw_fd()); | |
205 Ok(Socket::TcpListener(socket)) | |
206 } else { | |
207 // We still need to make the file descriptor harmless. | |
208 set_cloexec(value.into_raw_fd()); | |
209 Err(()) | |
210 } | |
211 } | |
212 } | |
213 | |
214 struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, StoredSocket>>); | |
176 | 215 |
177 impl SystemdSockets { | 216 impl SystemdSockets { |
178 fn new_protected(explicit: bool) -> Result<Self, InitError> { | 217 fn new_protected(explicit: bool) -> Result<Self, InitError> { |
179 unsafe { Self::new(true, explicit) } | 218 unsafe { Self::new(true, explicit) } |
180 } | 219 } |
181 | 220 |
182 unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> { | 221 unsafe fn new(protected: bool, explicit: bool) -> Result<Self, InitError> { |
222 use std::convert::TryFrom; | |
223 | |
183 if explicit { | 224 if explicit { |
184 if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() { | 225 if std::env::var_os("LISTEN_PID").is_none() && std::env::var_os("LISTEN_FDS").is_none() && std::env::var_os("LISTEN_FDNAMES").is_none() { |
185 // Systemd is not used - make the map empty | 226 // Systemd is not used - make the map empty |
186 return Ok(SystemdSockets(Mutex::new(Default::default()))); | 227 return Ok(SystemdSockets(Mutex::new(Default::default()))); |
187 } | 228 } |
188 } | 229 } |
189 | 230 |
190 if protected { Self::check_single_thread()? } | 231 if protected { Self::check_single_thread()? } |
191 // MUST BE true FOR SAFETY!!! | 232 // MUST BE true FOR SAFETY!!! |
192 let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| (name, fd)).collect(); | 233 let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ protected)?.into_iter().map(|(fd, name)| { |
234 (name, Socket::try_from(fd)) | |
235 }).collect(); | |
193 Ok(SystemdSockets(Mutex::new(map))) | 236 Ok(SystemdSockets(Mutex::new(map))) |
194 } | 237 } |
195 | 238 |
196 fn check_single_thread() -> Result<(), InitError> { | 239 fn check_single_thread() -> Result<(), InitError> { |
197 use std::io::BufRead; | 240 use std::io::BufRead; |
213 line.clear(); | 256 line.clear(); |
214 } | 257 } |
215 Ok(()) | 258 Ok(()) |
216 } | 259 } |
217 | 260 |
218 fn take(&self, name: &str) -> Option<FileDescriptor> { | 261 fn take(&self, name: &str) -> Option<StoredSocket> { |
219 // MUST remove THE SOCKET FOR SAFETY!!! | 262 // MUST remove THE SOCKET FOR SAFETY!!! |
220 self.0.lock().expect("poisoned mutex").remove(name) | 263 self.0.lock().expect("poisoned mutex").remove(name) |
221 } | 264 } |
222 } | 265 } |
223 | 266 |
412 } | 455 } |
413 } | 456 } |
414 | 457 |
415 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 458 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
416 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { | 459 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { |
417 use libsystemd::activation::IsType; | 460 use systemd_sockets::Socket; |
418 use std::os::unix::io::{FromRawFd, IntoRawFd}; | |
419 | 461 |
420 let real_systemd_name = if prefixed { | 462 let real_systemd_name = if prefixed { |
421 &socket_name[SYSTEMD_PREFIX.len()..] | 463 &socket_name[SYSTEMD_PREFIX.len()..] |
422 } else { | 464 } else { |
423 &socket_name | 465 &socket_name |
424 }; | 466 }; |
425 | 467 |
426 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; | 468 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; |
427 // Safety: The environment variable is unset, so that no other calls can get the | 469 // match instead of combinators to avoid cloning socket_name |
428 // descriptors. The descriptors are taken from the map, not cloned, so they can't | 470 match socket { |
429 // be duplicated. | 471 Some(Ok(Socket::TcpListener(socket))) => Ok((socket, SocketAddrInner::Systemd(socket_name))), |
430 unsafe { | 472 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), |
431 // match instead of combinators to avoid cloning socket_name | 473 None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) |
432 match socket { | |
433 Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), | |
434 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), | |
435 None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) | |
436 } | |
437 } | 474 } |
438 } | 475 } |
439 | 476 |
440 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it | 477 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it |
441 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. | 478 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. |
452 /// systemd socket but at that time there may be other threads running and error reporting also | 489 /// systemd socket but at that time there may be other threads running and error reporting also |
453 /// faces some restrictions. This function provides better control over the initialization point | 490 /// faces some restrictions. This function provides better control over the initialization point |
454 /// and returns a more idiomatic error type. | 491 /// and returns a more idiomatic error type. |
455 /// | 492 /// |
456 /// You should generally call this at around the top of `main`, where no threads were created yet. | 493 /// You should generally call this at around the top of `main`, where no threads were created yet. |
494 /// While technically you may spawn a thread and call this function after that thread has terminated, | |
495 /// this has the additional problem that the descriptors are still around, so if that thread (or the | |
496 /// current one!) forks and execs, the descriptors will leak into the child. | |
457 #[inline] | 497 #[inline] |
458 pub fn init() -> Result<(), error::InitError> { | 498 pub fn init() -> Result<(), error::InitError> { |
459 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 499 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
460 { | 500 { |
461 // Calling with true is always sound | 501 // Calling with true is always sound |
475 /// | 515 /// |
476 /// If for any reason you're unable to call `init` in a single thread at around the top of `main` | 516 /// If for any reason you're unable to call `init` in a single thread at around the top of `main` |
477 /// (and this should be almost never) you may call this method if you've ensured that no other part | 517 /// (and this should be almost never) you may call this method if you've ensured that no other part |
478 /// of your codebase is operating on systemd-provided file descriptors stored in the environment | 518 /// of your codebase is operating on systemd-provided file descriptors stored in the environment |
479 /// variables. | 519 /// variables. |
520 /// | |
521 /// Note however that doing so uncovers another problem: if another thread forks and execs, the | |
522 /// systemd file descriptors will get passed into that program! In such a case you somehow need to | |
523 /// clean up the file descriptors yourself. | |
480 pub unsafe fn init_unprotected() -> Result<(), error::InitError> { | 524 pub unsafe fn init_unprotected() -> Result<(), error::InitError> { |
481 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 525 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
482 { | 526 { |
483 systemd_sockets::init(false).map_err(error::InitError) | 527 systemd_sockets::init(false).map_err(error::InitError) |
484 } | 528 } |