Mercurial > crates > systemd-socket
view src/lib.rs @ 4:66c0e10c89fc
Support resolving hostnames
Until now the crate supported only IP addresses and systemd sockets.
This was troublesome because it prevented the popular `localhost:1234`
format. This commit changes the behavior so that if parsing the input as
`std::net::SocketAddr` fails, the crate attempts to parse it as `hostname:port`.
`bind_*()` methods were also modified to be async because of this.
author | Martin Habovstiak <martin.habovstiak@gmail.com> |
---|---|
date | Fri, 27 Nov 2020 15:05:19 +0100 |
parents | 0edcde404b02 |
children | a7893294e9b2 |
line wrap: on
line source
//! A convenience crate for optionally supporting systemd socket activation. //! //! ## About //! //! The goal of this crate is to make socket activation with systemd in your project trivial. //! It provides a replacement for `std::net::SocketAddr` that allows parsing the bind address from string just like the one from `std` //! but on top of that also allows `systemd://socket_name` format that tells it to use systemd activation with given socket name. //! Then it provides a method to bind the address which will return the socket from systemd if available. //! //! The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flag. //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format. //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding. //! //! Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are //! activated. //! //! ## Example //! //! ```no_run //! use systemd_socket::SocketAddr; //! use std::convert::TryFrom; //! use std::io::Write; //! //! let mut args = std::env::args_os(); //! let program_name = args.next().expect("unknown program name"); //! let socket_addr = args.next().expect("missing socket address"); //! let socket_addr = SocketAddr::try_from(socket_addr).expect("failed to parse socket address"); //! let socket = socket_addr.bind().expect("failed to bind socket"); //! //! loop { //! let _ = socket //! .accept() //! .expect("failed to accept connection") //! .0 //! .write_all(b"Hello world!") //! .map_err(|err| eprintln!("Failed to send {}", err)); //! } //! ``` //! //! ## Features //! //! * `serde` - implements `serde::Deserialize` for `SocketAddr` //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr` //! * `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr` //! 
* `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr` //! * `async_std` - adds `bind_async_std` method to `SocketAddr` //! //! ## MSRV //! //! This crate must always compile with the latest Rust available in the latest Debian stable. //! That is currently Rust 1.41.1. (Debian 10 - Buster) #![deny(missing_docs)] pub mod error; mod resolv_addr; use std::convert::{TryFrom, TryInto}; use std::fmt; use std::ffi::{OsStr, OsString}; use crate::error::*; use crate::resolv_addr::ResolvAddr; pub(crate) mod systemd_sockets { use std::fmt; use std::sync::Mutex; use libsystemd::activation::FileDescriptor; use libsystemd::errors::Error as LibSystemdError; use libsystemd::errors::Result as LibSystemdResult; #[derive(Debug)] pub(crate) struct Error(&'static Mutex<LibSystemdError>); impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&*self.0.lock().expect("mutex poisoned"), f) } } // No source we can't keep the mutex locked impl std::error::Error for Error {} pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> { match &*SYSTEMD_SOCKETS { Ok(sockets) => Ok(sockets.take(name)), Err(error) => Err(Error(error)) } } struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, FileDescriptor>>); impl SystemdSockets { fn new() -> LibSystemdResult<Self> { // MUST BE true FOR SAFETY!!! let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ true)?.into_iter().map(|(fd, name)| (name, fd)).collect(); Ok(SystemdSockets(Mutex::new(map))) } fn take(&self, name: &str) -> Option<FileDescriptor> { // MUST remove THE SOCKET FOR SAFETY!!! self.0.lock().expect("poisoned mutex").remove(name) } } lazy_static::lazy_static! 
{ // We don't panic in order to let the application handle the error later static ref SYSTEMD_SOCKETS: Result<SystemdSockets, Mutex<LibSystemdError>> = SystemdSockets::new().map_err(Mutex::new); } } /// Socket address that can be an ordinary address or a systemd socket /// /// This is the core type of this crate that abstracts possible addresses. /// It can be (fallibly) converted from various types of strings or deserialized with `serde`. /// After it's created, it can be bound as `TcpListener` from `std` or even `tokio` or `async_std` /// if the appropriate feature is enabled. /// /// Optional dependencies on `parse_arg` and `serde` make it trivial to use with /// [`configure_me`](https://crates.io/crates/configure_me). #[derive(Debug)] #[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))] pub struct SocketAddr(SocketAddrInner); impl SocketAddr { /// Creates `std::net::TcpListener` /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. pub fn bind(self) -> Result<std::net::TcpListener, BindError> { match self.0 { SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { Ok(socket) => Ok(socket), Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), }, SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) { Ok(socket) => Ok(socket), Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), }, SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket), } } /// Creates `tokio::net::TcpListener` /// /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.2 socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. 
#[cfg(feature = "tokio_0_2")] pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> { match self.0 { SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await { Ok(socket) => Ok(socket), Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), }, SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { Ok(socket) => Ok(socket), Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), }, SocketAddrInner::Systemd(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name)?; socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) }, } } /// Creates `tokio::net::TcpListener` /// /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.3 socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "tokio_0_3")] pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> { match self.0 { SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await { Ok(socket) => Ok(socket), Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), }, SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { Ok(socket) => Ok(socket), Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), }, SocketAddrInner::Systemd(socket_name) => { let (socket, addr) = Self::get_systemd(socket_name)?; socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) }, } } /// Creates `async_std::net::TcpListener` /// /// To be specific, it binds the socket or converts systemd socket to `async_std` socket. 
/// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "async-std")] pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> { match self.0 { SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await { Ok(socket) => Ok(socket), Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), }, SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await { Ok(socket) => Ok(socket), Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), }, SocketAddrInner::Systemd(socket_name) => { let (socket, _) = Self::get_systemd(socket_name)?; Ok(socket.into()) }, } } // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan // rules. fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> { if string.starts_with(SYSTEMD_PREFIX) { let name_len = string.len() - SYSTEMD_PREFIX.len(); match string[SYSTEMD_PREFIX.len()..].chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') { None if name_len <= 255 => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))), None => Err(ParseErrorInner::LongSocketName { string: string.into(), len: name_len }.into()), Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()), } } else { match string.parse() { Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))), } } } fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { use libsystemd::activation::IsType; use std::os::unix::io::{FromRawFd, IntoRawFd}; let socket = 
systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; // Safety: The environment variable is unset, so that no other calls can get the // descriptors. The descriptors are taken from the map, not cloned, so they can't // be duplicated. unsafe { // match instead of combinators to avoid cloning socket_name match socket { Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) } } } } /// Displays the address in format that can be parsed again. /// /// **Important: While I don't expect this impl to change, don't rely on it!** /// It should be used mostly for debugging/logging. impl fmt::Display for SocketAddr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(&self.0, f) } } impl fmt::Display for SocketAddrInner { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f), } } } // PartialEq for testing, I'm not convinced it should be exposed #[derive(Debug, PartialEq)] enum SocketAddrInner { Ordinary(std::net::SocketAddr), WithHostname(resolv_addr::ResolvAddr), Systemd(String), } const SYSTEMD_PREFIX: &str = "systemd://"; impl std::str::FromStr for SocketAddr { type Err = ParseError; fn from_str(s: &str) -> Result<Self, Self::Err> { SocketAddr::try_from_generic(s) } } impl<'a> TryFrom<&'a str> for SocketAddr { type Error = ParseError; fn try_from(s: &'a str) -> Result<Self, Self::Error> { SocketAddr::try_from_generic(s) } } impl TryFrom<String> for SocketAddr { type Error = ParseError; fn try_from(s: String) -> Result<Self, Self::Error> { SocketAddr::try_from_generic(s) } } impl<'a> 
TryFrom<&'a OsStr> for SocketAddr { type Error = ParseOsStrError; fn try_from(s: &'a OsStr) -> Result<Self, Self::Error> { s.to_str().ok_or(ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into) } } impl TryFrom<OsString> for SocketAddr { type Error = ParseOsStrError; fn try_from(s: OsString) -> Result<Self, Self::Error> { s.into_string().map_err(|_| ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into) } } #[cfg(feature = "serde")] impl<'a> TryFrom<serde_str_helpers::DeserBorrowStr<'a>> for SocketAddr { type Error = ParseError; fn try_from(s: serde_str_helpers::DeserBorrowStr<'a>) -> Result<Self, Self::Error> { SocketAddr::try_from_generic(std::borrow::Cow::from(s)) } } #[cfg(feature = "parse_arg")] impl parse_arg::ParseArg for SocketAddr { type Error = ParseOsStrError; fn describe_type<W: fmt::Write>(mut writer: W) -> fmt::Result { std::net::SocketAddr::describe_type(&mut writer)?; write!(writer, " or a systemd socket name prefixed with systemd://") } fn parse_arg(arg: &OsStr) -> Result<Self, Self::Error> { arg.try_into() } fn parse_owned_arg(arg: OsString) -> Result<Self, Self::Error> { arg.try_into() } } #[cfg(test)] mod tests { use super::{SocketAddr, SocketAddrInner}; #[test] fn parse_ordinary() { assert_eq!("127.0.0.1:42".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into())); } #[test] fn parse_systemd() { assert_eq!("systemd://foo".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Systemd("systemd://foo".to_owned())); } #[test] #[should_panic] fn parse_systemd_fail_control() { "systemd://foo\n".parse::<SocketAddr>().unwrap(); } #[test] #[should_panic] fn parse_systemd_fail_colon() { "systemd://foo:".parse::<SocketAddr>().unwrap(); } #[test] #[should_panic] fn parse_systemd_fail_non_ascii() { "systemd://fooĆ”".parse::<SocketAddr>().unwrap(); } #[test] #[should_panic] fn parse_systemd_fail_too_long() { 
"systemd://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".parse::<SocketAddr>().unwrap(); } }