changeset 4:66c0e10c89fc

Support resolving hostnames Until now the crate supported only IP addresses and systemd sockets. This was troublesome because it prevented the popular `localhost:1234` format. This commit changes the behavior so that if parsing of `std::net::SocketAddr` fails it attempts to parse it as `hostname:port`. `bind_*()` methods were also modified to be async because of this.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Fri, 27 Nov 2020 15:05:19 +0100
parents 0edcde404b02
children 27456533853e
files Cargo.toml README.md src/error.rs src/lib.rs src/resolv_addr.rs tests/ordinary.rs
diffstat 6 files changed, 171 insertions(+), 48 deletions(-) [+]
line wrap: on
line diff
--- a/Cargo.toml	Fri Nov 27 12:56:37 2020 +0100
+++ b/Cargo.toml	Fri Nov 27 15:05:19 2020 +0100
@@ -16,6 +16,6 @@
 parse_arg = { version = "0.1.4", optional = true }
 libsystemd = "0.2.1"
 lazy_static = "1.4.0"
-tokio_0_2 = { package = "tokio", version = "0.2", optional = true, features = ["tcp"] }
+tokio_0_2 = { package = "tokio", version = "0.2", optional = true, features = ["tcp", "dns"] }
 tokio_0_3 = { package = "tokio", version = "0.3", optional = true, features = ["net"] }
 async-std = { version = "1.7.0", optional = true }
--- a/README.md	Fri Nov 27 12:56:37 2020 +0100
+++ b/README.md	Fri Nov 27 15:05:19 2020 +0100
@@ -13,7 +13,8 @@
 Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format.
 You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding.
 
-Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and `async_std` sockets.
+Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are
+activated.
 
 ## Example
 
@@ -42,9 +43,9 @@
 
 * `serde` - implements `serde::Deserialize` for `SocketAddr`
 * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr`
-* `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr`
-* `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr`
-* `async_std` - adds `bind_async_std` convenience method to `SocketAddr`
+* `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr`
+* `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr`
+* `async_std` - adds `bind_async_std` method to `SocketAddr`
 
 ## MSRV
 
--- a/src/error.rs	Fri Nov 27 12:56:37 2020 +0100
+++ b/src/error.rs	Fri Nov 27 15:05:19 2020 +0100
@@ -18,7 +18,7 @@
 #[derive(Debug, Error)]
 pub(crate) enum ParseErrorInner {
     #[error("failed to parse socket address")]
-    SocketAddr(std::net::AddrParseError),
+    ResolvAddr(#[from] crate::resolv_addr::ResolvAddrError),
     #[error("invalid character '{c}' in systemd socket name {string} at position {pos}")]
     InvalidCharacter { string: String, c: char, pos: usize, },
     #[error("systemd socket name {string} is {len} characters long which is more than the limit 255")]
@@ -58,6 +58,8 @@
 pub(crate) enum BindErrorInner {
     #[error("failed to bind {addr}")]
     BindFailed { addr: std::net::SocketAddr, #[source] error: io::Error, },
+    #[error("failed to bind {addr}")]
+    BindOrResolvFailed { addr: crate::resolv_addr::ResolvAddr, #[source] error: io::Error, },
     #[error("failed to receive descriptors with names")]
     ReceiveDescriptors(#[source] crate::systemd_sockets::Error),
     #[error("missing systemd socket {0} - a typo or an attempt to bind twice")]
--- a/src/lib.rs	Fri Nov 27 12:56:37 2020 +0100
+++ b/src/lib.rs	Fri Nov 27 15:05:19 2020 +0100
@@ -11,8 +11,8 @@
 //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format.
 //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding.
 //!
-//! Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and
-//! `async_std` sockets if the appropriate features are activated.
+//! Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are
+//! activated.
 //! 
 //! ## Example
 //! 
@@ -41,9 +41,9 @@
 //!
 //! * `serde` - implements `serde::Deserialize` for `SocketAddr`
 //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr`
-//! * `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr`
-//! * `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr`
-//! * `async_std` - adds `bind_async_std` convenience method to `SocketAddr`
+//! * `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr`
+//! * `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr`
+//! * `async_std` - adds `bind_async_std` method to `SocketAddr`
 //!
 //! ## MSRV
 //!
@@ -54,11 +54,13 @@
 #![deny(missing_docs)]
 
 pub mod error;
+mod resolv_addr;
 
 use std::convert::{TryFrom, TryInto};
 use std::fmt;
 use std::ffi::{OsStr, OsString};
 use crate::error::*;
+use crate::resolv_addr::ResolvAddr;
 
 pub(crate) mod systemd_sockets {
     use std::fmt;
@@ -126,43 +128,89 @@
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
-        self._bind().map(|(socket, _)| socket)
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket),
+        }
     }
 
     /// Creates `tokio::net::TcpListener`
     ///
-    /// To be specific, it binds the socket and converts it to `tokio` 0.2 socket.
+    /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.2 socket.
     ///
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     #[cfg(feature = "tokio_0_2")]
-    pub fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
-        let (socket, addr) = self._bind()?;
-        socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+    pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::Systemd(socket_name) => {
+                let (socket, addr) = Self::get_systemd(socket_name)?;
+                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+            },
+        }
     }
 
     /// Creates `tokio::net::TcpListener`
     ///
-    /// To be specific, it binds the socket and converts it to `tokio` 0.3 socket.
+    /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.3 socket.
     ///
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     #[cfg(feature = "tokio_0_3")]
-    pub fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
-        let (socket, addr) = self._bind()?;
-        socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+    pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::Systemd(socket_name) => {
+                let (socket, addr) = Self::get_systemd(socket_name)?;
+                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+            },
+        }
     }
 
     /// Creates `async_std::net::TcpListener`
     ///
-    /// To be specific, it binds the socket and converts it to `async_std` socket.
+    /// To be specific, it binds the socket or converts systemd socket to `async_std` socket.
     ///
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     #[cfg(feature = "async-std")]
-    pub fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
-        let (socket, _) = self._bind()?;
-        Ok(socket.into())
+    pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::Systemd(socket_name) => {
+                let (socket, _) = Self::get_systemd(socket_name)?;
+                Ok(socket.into())
+            },
+        }
     }
 
     // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan
@@ -176,33 +224,28 @@
                 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()),
             }
         } else {
-            Ok(string.parse().map(SocketAddrInner::Ordinary).map(SocketAddr).map_err(ParseErrorInner::SocketAddr)?)
+            match string.parse() {
+                Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))),
+                Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))),
+            }
         }
     }
 
-    fn _bind(self) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
-        match self.0 {
-            SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
-                Ok(socket) => Ok((socket, SocketAddrInner::Ordinary(addr))),
-                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
-            },
-            SocketAddrInner::Systemd(socket_name) => {
-                use libsystemd::activation::IsType;
-                use std::os::unix::io::{FromRawFd, IntoRawFd};
+    fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
+        use libsystemd::activation::IsType;
+        use std::os::unix::io::{FromRawFd, IntoRawFd};
 
-                let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?;
-                // Safety: The environment variable is unset, so that no other calls can get the
-                // descriptors. The descriptors are taken from the map, not cloned, so they can't
-                // be duplicated.
-                unsafe {
-                    // match instead of combinators to avoid cloning socket_name
-                    match socket {
-                        Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
-                        Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
-                        None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
-                    }
-                }
-            },
+        let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?;
+        // Safety: The environment variable is unset, so that no other calls can get the
+        // descriptors. The descriptors are taken from the map, not cloned, so they can't
+        // be duplicated.
+        unsafe {
+            // match instead of combinators to avoid cloning socket_name
+            match socket {
+                Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
+                Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
+                None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
+            }
         }
     }
 }
@@ -222,6 +265,7 @@
         match self {
             SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f),
             SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f),
+            SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f),
         }
     }
 }
@@ -230,6 +274,7 @@
 #[derive(Debug, PartialEq)]
 enum SocketAddrInner {
     Ordinary(std::net::SocketAddr),
+    WithHostname(resolv_addr::ResolvAddr),
     Systemd(String),
 }
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/resolv_addr.rs	Fri Nov 27 15:05:19 2020 +0100
@@ -0,0 +1,75 @@
+use thiserror::Error;
+use std::fmt;
+
+#[derive(Debug, PartialEq)]
+pub(crate) struct ResolvAddr(String);
+
+impl ResolvAddr {
+    pub(crate) fn as_str(&self) -> &str {
+        &self.0
+    }
+
+    pub(crate) fn try_from_generic<T: std::ops::Deref<Target=str> + Into<String>>(string: T) -> Result<Self, ResolvAddrError> {
+        // can't use a combinator due to borrowing
+        let colon = match string.rfind(':') {
+            Some(colon) => colon,
+            None => return Err(ResolvAddrError::MissingPort(string.into())),
+        };
+
+        let (hostname, port) = string.split_at(colon);
+
+        if let Err(error) = port[1..].parse::<u16>() {
+            return Err(ResolvAddrError::InvalidPort { string: string.into(), error, });
+        }
+
+        let len = hostname.len();
+        if len > 253 {
+            return Err(ResolvAddrError::TooLong { string: string.into(), len, } )
+        }
+
+        let mut label_start = 0usize;
+
+        for (i, c) in hostname.chars().enumerate() {
+            match c {
+                '.' => {
+                    if i - label_start == 0 {
+                        return Err(ResolvAddrError::EmptyLabel { string: string.into(), label_start, });
+                    }
+
+                    label_start = i + 1;
+                },
+                'a'..='z' | 'A'..='Z' | '0'..='9' | '-' => (),
+                _ => return Err(ResolvAddrError::InvalidCharacter { string: string.into(), c, pos: i, }),
+            }
+
+            if i - label_start > 63 {
+                return Err(ResolvAddrError::LongLabel { string: string.into(), label_start, label_end: i, });
+            }
+        }
+
+        Ok(ResolvAddr(string.into()))
+    }
+}
+
+impl fmt::Display for ResolvAddr {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        fmt::Display::fmt(&self.0, f)
+    }
+}
+
+
+#[derive(Debug, Error)]
+pub(crate) enum ResolvAddrError {
+    #[error("hostname {string} has {len} character which exceeds the limit of 253")]
+    TooLong { string: String, len: usize },
+    #[error("invalid character {c} in hostname {string} at position {pos}")]
+    InvalidCharacter { string: String, pos: usize, c: char, },
+    #[error("hostname {string} contains a label {} at position {label_start} which is {} characters long - more than the limit 63", &string[(*label_start)..(*label_end)], label_end - label_start)]
+    LongLabel { string: String, label_start: usize, label_end: usize, },
+    #[error("hostname {string} contains an empty label at position {label_start}")]
+    EmptyLabel { string: String, label_start: usize, },
+    #[error("the address {0} is missing a port")]
+    MissingPort(String),
+    #[error("failed to parse port numer in the address {string}")]
+    InvalidPort { string: String, error: std::num::ParseIntError, },
+}
--- a/tests/ordinary.rs	Fri Nov 27 12:56:37 2020 +0100
+++ b/tests/ordinary.rs	Fri Nov 27 15:05:19 2020 +0100
@@ -7,7 +7,7 @@
 enum Test {}
 
 impl comm::Test for Test {
-    const SOCKET_ADDR: &'static str = "127.0.0.1:4242";
+    const SOCKET_ADDR: &'static str = "localhost:4242";
 
     fn spawn_slave(program_name: &OsStr) -> io::Result<Child> {
         Command::new(program_name)