Mercurial > crates > systemd-socket
changeset 17:dfb727367934
Allow specifying systemd name directly
When an application has systemd name by parsing a configuration directly
it can use the function added in this change to construct `SocketAddr`
without having to allocate an intermediate string.
author | Martin Habovstiak <martin.habovstiak@gmail.com> |
---|---|
date | Tue, 22 Dec 2020 14:15:49 +0100 |
parents | bc76507dd878 |
children | db1dc99252e2 |
files | src/lib.rs |
diffstat | 1 files changed, 82 insertions(+), 20 deletions(-) [+] |
line wrap: on
line diff
--- a/src/lib.rs Tue Dec 22 13:58:47 2020 +0100 +++ b/src/lib.rs Tue Dec 22 14:15:49 2020 +0100 @@ -133,6 +133,36 @@ pub struct SocketAddr(SocketAddrInner); impl SocketAddr { + /// Creates SocketAddr from systemd name directly, without requiring `systemd://` prefix. + /// + /// Always fails with systemd unsupported error if systemd is not supported. + pub fn from_systemd_name<T: Into<String>>(name: T) -> Result<Self, ParseError> { + Self::inner_from_systemd_name(name.into(), false) + } + + #[cfg(all(target_os = "linux", feature = "enable_systemd"))] + fn inner_from_systemd_name(name: String, prefixed: bool) -> Result<Self, ParseError> { + let real_systemd_name = if prefixed { + &name[SYSTEMD_PREFIX.len()..] + } else { + &name + }; + + let name_len = real_systemd_name.len(); + match real_systemd_name.chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') { + None if name_len <= 255 && prefixed => Ok(SocketAddr(SocketAddrInner::Systemd(name))), + None if name_len <= 255 && !prefixed => Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))), + None => Err(ParseErrorInner::LongSocketName { string: name, len: name_len }.into()), + Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into()), + } + } + + + #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] + fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> { + Err(ParseError(ParseErrorInner::SystemdUnsupported(name))) + } + /// Creates `std::net::TcpListener` /// /// This method either `binds` the socket, if the address was provided or uses systemd socket @@ -147,7 +177,8 @@ Ok(socket) => Ok(socket), Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), }, - SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket), + SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket), + SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket), } } @@ -169,7 +200,11 @@ Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), }, SocketAddrInner::Systemd(socket_name) => { - let (socket, addr) = Self::get_systemd(socket_name)?; + let (socket, addr) = Self::get_systemd(socket_name, true)?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + }, + SocketAddrInner::SystemdNoPrefix(socket_name) => { + let (socket, addr) = Self::get_systemd(socket_name, false)?; socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) }, } @@ -193,7 +228,11 @@ Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), }, SocketAddrInner::Systemd(socket_name) => { - let (socket, addr) = Self::get_systemd(socket_name)?; + let (socket, addr) = Self::get_systemd(socket_name, true)?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + }, + SocketAddrInner::SystemdNoPrefix(socket_name) => { + let (socket, addr) = Self::get_systemd(socket_name, false)?; socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) }, } @@ -217,7 +256,11 @@ Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), }, SocketAddrInner::Systemd(socket_name) => { - let (socket, _) = Self::get_systemd(socket_name)?; + let (socket, _) = Self::get_systemd(socket_name, true)?; + Ok(socket.into()) + }, + SocketAddrInner::SystemdNoPrefix(socket_name) => { + let (socket, _) = Self::get_systemd(socket_name, false)?; Ok(socket.into()) }, } @@ -227,19 +270,7 @@ // rules. fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> { if string.starts_with(SYSTEMD_PREFIX) { - #[cfg(all(target_os = "linux", feature = "enable_systemd"))] - { - let name_len = string.len() - SYSTEMD_PREFIX.len(); - match string[SYSTEMD_PREFIX.len()..].chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') { - None if name_len <= 255 => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))), - None => Err(ParseErrorInner::LongSocketName { string: string.into(), len: name_len }.into()), - Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()), - } - } - #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] - { - Err(ParseErrorInner::SystemdUnsupported(string.into()).into()) - } + Self::inner_from_systemd_name(string.into(), true) } else { match string.parse() { Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), @@ -249,11 +280,17 @@ } #[cfg(all(target_os = "linux", feature = "enable_systemd"))] - fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { + fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { use libsystemd::activation::IsType; use std::os::unix::io::{FromRawFd, IntoRawFd}; - let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; + let real_systemd_name = if prefixed { + &socket_name[SYSTEMD_PREFIX.len()..] + } else { + &socket_name + }; + + let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; // Safety: The environment variable is unset, so that no other calls can get the // descriptors. The descriptors are taken from the map, not cloned, so they can't // be duplicated. @@ -270,7 +307,7 @@ // This approach makes the rest of the code much simpler as it doesn't require sprinkling it // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] - fn get_systemd(socket_name: Never) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { + fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { match socket_name {} } } @@ -290,6 +327,7 @@ match self { SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), + SocketAddrInner::SystemdNoPrefix(addr) => write!(f, "{}{}", SYSTEMD_PREFIX, addr), SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f), } } @@ -305,6 +343,12 @@ #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] #[allow(dead_code)] Systemd(Never), + #[cfg(all(target_os = "linux", feature = "enable_systemd"))] + #[allow(dead_code)] + SystemdNoPrefix(String), + #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] + #[allow(dead_code)] + SystemdNoPrefix(Never), } const SYSTEMD_PREFIX: &str = "systemd://"; @@ -439,4 +483,22 @@ fn parse_systemd_fail_too_long() { "systemd://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".parse::<SocketAddr>().unwrap(); } + + #[test] + #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)] + fn no_prefix_parse_systemd() { + SocketAddr::from_systemd_name("foo").unwrap(); + } + + #[test] + #[should_panic] + fn no_prefix_parse_systemd_fail_non_ascii() { + SocketAddr::from_systemd_name("fooĆ”").unwrap(); + } + + #[test] + #[should_panic] + fn no_prefix_parse_systemd_fail_too_long() { + SocketAddr::from_systemd_name("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx").unwrap(); + } }