Mercurial > crates > systemd-socket
changeset 4:66c0e10c89fc
Support resolving hostnames
Until now the crate supported only IP addresses and systemd sockets.
This was troublesome because it prevented the popular `localhost:1234`
format. This commit changes the behavior so that if parsing of
`std::net::SocketAddr` fails it attempts to parse it as `hostname:port`.
`bind_*()` methods were also modified to be async because of this.
author | Martin Habovstiak <martin.habovstiak@gmail.com> |
---|---|
date | Fri, 27 Nov 2020 15:05:19 +0100 |
parents | 0edcde404b02 |
children | 27456533853e |
files | Cargo.toml README.md src/error.rs src/lib.rs src/resolv_addr.rs tests/ordinary.rs |
diffstat | 6 files changed, 171 insertions(+), 48 deletions(-) [+] |
line wrap: on
line diff
--- a/Cargo.toml Fri Nov 27 12:56:37 2020 +0100 +++ b/Cargo.toml Fri Nov 27 15:05:19 2020 +0100 @@ -16,6 +16,6 @@ parse_arg = { version = "0.1.4", optional = true } libsystemd = "0.2.1" lazy_static = "1.4.0" -tokio_0_2 = { package = "tokio", version = "0.2", optional = true, features = ["tcp"] } +tokio_0_2 = { package = "tokio", version = "0.2", optional = true, features = ["tcp", "dns"] } tokio_0_3 = { package = "tokio", version = "0.3", optional = true, features = ["net"] } async-std = { version = "1.7.0", optional = true }
--- a/README.md Fri Nov 27 12:56:37 2020 +0100 +++ b/README.md Fri Nov 27 15:05:19 2020 +0100 @@ -13,7 +13,8 @@ Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format. You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding. -Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and `async_std` sockets. +Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are +activated. ## Example @@ -42,9 +43,9 @@ * `serde` - implements `serde::Deserialize` for `SocketAddr` * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr` -* `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr` -* `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr` -* `async_std` - adds `bind_async_std` convenience method to `SocketAddr` +* `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr` +* `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr` +* `async_std` - adds `bind_async_std` method to `SocketAddr` ## MSRV
--- a/src/error.rs Fri Nov 27 12:56:37 2020 +0100 +++ b/src/error.rs Fri Nov 27 15:05:19 2020 +0100 @@ -18,7 +18,7 @@ #[derive(Debug, Error)] pub(crate) enum ParseErrorInner { #[error("failed to parse socket address")] - SocketAddr(std::net::AddrParseError), + ResolvAddr(#[from] crate::resolv_addr::ResolvAddrError), #[error("invalid character '{c}' in systemd socket name {string} at position {pos}")] InvalidCharacter { string: String, c: char, pos: usize, }, #[error("systemd socket name {string} is {len} characters long which is more than the limit 255")] @@ -58,6 +58,8 @@ pub(crate) enum BindErrorInner { #[error("failed to bind {addr}")] BindFailed { addr: std::net::SocketAddr, #[source] error: io::Error, }, + #[error("failed to bind {addr}")] + BindOrResolvFailed { addr: crate::resolv_addr::ResolvAddr, #[source] error: io::Error, }, #[error("failed to receive descriptors with names")] ReceiveDescriptors(#[source] crate::systemd_sockets::Error), #[error("missing systemd socket {0} - a typo or an attempt to bind twice")]
--- a/src/lib.rs Fri Nov 27 12:56:37 2020 +0100 +++ b/src/lib.rs Fri Nov 27 15:05:19 2020 +0100 @@ -11,8 +11,8 @@ //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format. //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding. //! -//! Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and -//! `async_std` sockets if the appropriate features are activated. +//! Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are +//! activated. //! //! ## Example //! @@ -41,9 +41,9 @@ //! //! * `serde` - implements `serde::Deserialize` for `SocketAddr` //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr` -//! * `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr` -//! * `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr` -//! * `async_std` - adds `bind_async_std` convenience method to `SocketAddr` +//! * `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr` +//! * `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr` +//! * `async_std` - adds `bind_async_std` method to `SocketAddr` //! //! ## MSRV //! @@ -54,11 +54,13 @@ #![deny(missing_docs)] pub mod error; +mod resolv_addr; use std::convert::{TryFrom, TryInto}; use std::fmt; use std::ffi::{OsStr, OsString}; use crate::error::*; +use crate::resolv_addr::ResolvAddr; pub(crate) mod systemd_sockets { use std::fmt; @@ -126,43 +128,89 @@ /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. pub fn bind(self) -> Result<std::net::TcpListener, BindError> { - self._bind().map(|(socket, _)| socket) + match self.0 { + SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), + }, + SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), + }, + SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket), + } } /// Creates `tokio::net::TcpListener` /// - /// To be specific, it binds the socket and converts it to `tokio` 0.2 socket. + /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.2 socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "tokio_0_2")] - pub fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> { - let (socket, addr) = self._bind()?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> { + match self.0 { + SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), + }, + SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), + }, + SocketAddrInner::Systemd(socket_name) => { + let (socket, addr) = Self::get_systemd(socket_name)?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + }, + } } /// Creates `tokio::net::TcpListener` /// - /// To be specific, it binds the socket and converts it to `tokio` 0.3 socket. + /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.3 socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "tokio_0_3")] - pub fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> { - let (socket, addr) = self._bind()?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> { + match self.0 { + SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), + }, + SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), + }, + SocketAddrInner::Systemd(socket_name) => { + let (socket, addr) = Self::get_systemd(socket_name)?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + }, + } } /// Creates `async_std::net::TcpListener` /// - /// To be specific, it binds the socket and converts it to `async_std` socket. + /// To be specific, it binds the socket or converts systemd socket to `async_std` socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "async-std")] - pub fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> { - let (socket, _) = self._bind()?; - Ok(socket.into()) + pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> { + match self.0 { + SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), + }, + SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), + }, + SocketAddrInner::Systemd(socket_name) => { + let (socket, _) = Self::get_systemd(socket_name)?; + Ok(socket.into()) + }, + } } // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan @@ -176,33 +224,28 @@ Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()), } } else { - Ok(string.parse().map(SocketAddrInner::Ordinary).map(SocketAddr).map_err(ParseErrorInner::SocketAddr)?) + match string.parse() { + Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), + Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))), + } } } - fn _bind(self) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { - match self.0 { - SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { - Ok(socket) => Ok((socket, SocketAddrInner::Ordinary(addr))), - Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), - }, - SocketAddrInner::Systemd(socket_name) => { - use libsystemd::activation::IsType; - use std::os::unix::io::{FromRawFd, IntoRawFd}; + fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { + use libsystemd::activation::IsType; + use std::os::unix::io::{FromRawFd, IntoRawFd}; - let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; - // Safety: The environment variable is unset, so that no other calls can get the - // descriptors. The descriptors are taken from the map, not cloned, so they can't - // be duplicated. - unsafe { - // match instead of combinators to avoid cloning socket_name - match socket { - Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), - Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), - None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) - } - } - }, + let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; + // Safety: The environment variable is unset, so that no other calls can get the + // descriptors. The descriptors are taken from the map, not cloned, so they can't + // be duplicated. + unsafe { + // match instead of combinators to avoid cloning socket_name + match socket { + Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), + Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), + None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) + } } } } @@ -222,6 +265,7 @@ match self { SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), + SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f), } } } @@ -230,6 +274,7 @@ #[derive(Debug, PartialEq)] enum SocketAddrInner { Ordinary(std::net::SocketAddr), + WithHostname(resolv_addr::ResolvAddr), Systemd(String), }
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/resolv_addr.rs Fri Nov 27 15:05:19 2020 +0100 @@ -0,0 +1,75 @@ +use thiserror::Error; +use std::fmt; + +#[derive(Debug, PartialEq)] +pub(crate) struct ResolvAddr(String); + +impl ResolvAddr { + pub(crate) fn as_str(&self) -> &str { + &self.0 + } + + pub(crate) fn try_from_generic<T: std::ops::Deref<Target=str> + Into<String>>(string: T) -> Result<Self, ResolvAddrError> { + // can't use a combinator due to borrowing + let colon = match string.rfind(':') { + Some(colon) => colon, + None => return Err(ResolvAddrError::MissingPort(string.into())), + }; + + let (hostname, port) = string.split_at(colon); + + if let Err(error) = port[1..].parse::<u16>() { + return Err(ResolvAddrError::InvalidPort { string: string.into(), error, }); + } + + let len = hostname.len(); + if len > 253 { + return Err(ResolvAddrError::TooLong { string: string.into(), len, } ) + } + + let mut label_start = 0usize; + + for (i, c) in hostname.chars().enumerate() { + match c { + '.' => { + if i - label_start == 0 { + return Err(ResolvAddrError::EmptyLabel { string: string.into(), label_start, }); + } + + label_start = i + 1; + }, + 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' => (), + _ => return Err(ResolvAddrError::InvalidCharacter { string: string.into(), c, pos: i, }), + } + + if i - label_start > 63 { + return Err(ResolvAddrError::LongLabel { string: string.into(), label_start, label_end: i, }); + } + } + + Ok(ResolvAddr(string.into())) + } +} + +impl fmt::Display for ResolvAddr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + + +#[derive(Debug, Error)] +pub(crate) enum ResolvAddrError { + #[error("hostname {string} has {len} character which exceeds the limit of 253")] + TooLong { string: String, len: usize }, + #[error("invalid character {c} in hostname {string} at position {pos}")] + InvalidCharacter { string: String, pos: usize, c: char, }, + #[error("hostname {string} contains a label {} at position {label_start} which is {} characters long - more than the limit 63", &string[(*label_start)..(*label_end)], label_end - label_start)] + LongLabel { string: String, label_start: usize, label_end: usize, }, + #[error("hostname {string} contains an empty label at position {label_start}")] + EmptyLabel { string: String, label_start: usize, }, + #[error("the address {0} is missing a port")] + MissingPort(String), + #[error("failed to parse port numer in the address {string}")] + InvalidPort { string: String, error: std::num::ParseIntError, }, +}
--- a/tests/ordinary.rs Fri Nov 27 12:56:37 2020 +0100 +++ b/tests/ordinary.rs Fri Nov 27 15:05:19 2020 +0100 @@ -7,7 +7,7 @@ enum Test {} impl comm::Test for Test { - const SOCKET_ADDR: &'static str = "127.0.0.1:4242"; + const SOCKET_ADDR: &'static str = "localhost:4242"; fn spawn_slave(program_name: &OsStr) -> io::Result<Child> { Command::new(program_name)