changeset 17:dfb727367934

Allow specifying systemd name directly When an application has systemd name by parsing a configuration directly it can use the function added in this change to construct `SocketAddr` without having to allocate an intermediate string.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Tue, 22 Dec 2020 14:15:49 +0100
parents bc76507dd878
children db1dc99252e2
files src/lib.rs
diffstat 1 files changed, 82 insertions(+), 20 deletions(-) [+]
line wrap: on
line diff
--- a/src/lib.rs	Tue Dec 22 13:58:47 2020 +0100
+++ b/src/lib.rs	Tue Dec 22 14:15:49 2020 +0100
@@ -133,6 +133,36 @@
 pub struct SocketAddr(SocketAddrInner);
 
 impl SocketAddr {
+    /// Creates SocketAddr from systemd name directly, without requiring `systemd://` prefix.
+    ///
+    /// Always fails with systemd unsupported error if systemd is not supported.
+    pub fn from_systemd_name<T: Into<String>>(name: T) -> Result<Self, ParseError> {
+        Self::inner_from_systemd_name(name.into(), false)
+    }
+
+    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
+    fn inner_from_systemd_name(name: String, prefixed: bool) -> Result<Self, ParseError> {
+        let real_systemd_name = if prefixed {
+            &name[SYSTEMD_PREFIX.len()..]
+        } else {
+            &name
+        };
+
+        let name_len = real_systemd_name.len();
+        match real_systemd_name.chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') {
+            None if name_len <= 255 && prefixed => Ok(SocketAddr(SocketAddrInner::Systemd(name))),
+            None if name_len <= 255 && !prefixed => Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))),
+            None => Err(ParseErrorInner::LongSocketName { string: name, len: name_len }.into()),
+            Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into()),
+        }
+    }
+
+
+    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
+    fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> {
+        Err(ParseError(ParseErrorInner::SystemdUnsupported(name)))
+    }
+
     /// Creates `std::net::TcpListener`
     ///
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
@@ -147,7 +177,8 @@
                 Ok(socket) => Ok(socket),
                 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
             },
-            SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket),
+            SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket),
+            SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket),
         }
     }
 
@@ -169,7 +200,11 @@
                 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
             },
             SocketAddrInner::Systemd(socket_name) => {
-                let (socket, addr) = Self::get_systemd(socket_name)?;
+                let (socket, addr) = Self::get_systemd(socket_name, true)?;
+                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+            },
+            SocketAddrInner::SystemdNoPrefix(socket_name) => {
+                let (socket, addr) = Self::get_systemd(socket_name, false)?;
                 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
             },
         }
@@ -193,7 +228,11 @@
                 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
             },
             SocketAddrInner::Systemd(socket_name) => {
-                let (socket, addr) = Self::get_systemd(socket_name)?;
+                let (socket, addr) = Self::get_systemd(socket_name, true)?;
+                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+            },
+            SocketAddrInner::SystemdNoPrefix(socket_name) => {
+                let (socket, addr) = Self::get_systemd(socket_name, false)?;
                 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
             },
         }
@@ -217,7 +256,11 @@
                 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
             },
             SocketAddrInner::Systemd(socket_name) => {
-                let (socket, _) = Self::get_systemd(socket_name)?;
+                let (socket, _) = Self::get_systemd(socket_name, true)?;
+                Ok(socket.into())
+            },
+            SocketAddrInner::SystemdNoPrefix(socket_name) => {
+                let (socket, _) = Self::get_systemd(socket_name, false)?;
                 Ok(socket.into())
             },
         }
@@ -227,19 +270,7 @@
     // rules.
     fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> {
         if string.starts_with(SYSTEMD_PREFIX) {
-            #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
-            {
-                let name_len = string.len() - SYSTEMD_PREFIX.len();
-                match string[SYSTEMD_PREFIX.len()..].chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') {
-                    None if name_len <= 255 => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))),
-                    None => Err(ParseErrorInner::LongSocketName { string: string.into(), len: name_len }.into()),
-                    Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()),
-                }
-            }
-            #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
-            {
-                Err(ParseErrorInner::SystemdUnsupported(string.into()).into())
-            }
+            Self::inner_from_systemd_name(string.into(), true)
         } else {
             match string.parse() {
                 Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))),
@@ -249,11 +280,17 @@
     }
 
     #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
-    fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
+    fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
         use libsystemd::activation::IsType;
         use std::os::unix::io::{FromRawFd, IntoRawFd};
 
-        let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?;
+        let real_systemd_name = if prefixed {
+            &socket_name[SYSTEMD_PREFIX.len()..]
+        } else {
+            &socket_name
+        };
+
+        let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?;
         // Safety: The environment variable is unset, so that no other calls can get the
         // descriptors. The descriptors are taken from the map, not cloned, so they can't
         // be duplicated.
@@ -270,7 +307,7 @@
     // This approach makes the rest of the code much simpler as it doesn't require sprinkling it
     // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute.
     #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
-    fn get_systemd(socket_name: Never) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
+    fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
         match socket_name {}
     }
 }
@@ -290,6 +327,7 @@
         match self {
             SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f),
             SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f),
+            SocketAddrInner::SystemdNoPrefix(addr) => write!(f, "{}{}", SYSTEMD_PREFIX, addr),
             SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f),
         }
     }
@@ -305,6 +343,12 @@
     #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
     #[allow(dead_code)]
     Systemd(Never),
+    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
+    #[allow(dead_code)]
+    SystemdNoPrefix(String),
+    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
+    #[allow(dead_code)]
+    SystemdNoPrefix(Never),
 }
 
 const SYSTEMD_PREFIX: &str = "systemd://";
@@ -439,4 +483,22 @@
     fn parse_systemd_fail_too_long() {
         "systemd://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".parse::<SocketAddr>().unwrap();
     }
+
+    #[test]
+    #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
+    fn no_prefix_parse_systemd() {
+        SocketAddr::from_systemd_name("foo").unwrap();
+    }
+
+    #[test]
+    #[should_panic]
+    fn no_prefix_parse_systemd_fail_non_ascii() {
+        SocketAddr::from_systemd_name("fooĆ”").unwrap();
+    }
+
+    #[test]
+    #[should_panic]
+    fn no_prefix_parse_systemd_fail_too_long() {
+        SocketAddr::from_systemd_name("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx").unwrap();
+    }
 }