diff src/lib.rs @ 4:66c0e10c89fc

Support resolving hostnames Until now the crate supported only IP addresses and systemd sockets. This was troublesome because it prevented the popular `localhost:1234` format. This commit changes the behavior so that if parsing of `std::net::SocketAddr` fails it attempts to parse it as `hostname:port`. `bind_*()` methods were also modified to be async because of this.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Fri, 27 Nov 2020 15:05:19 +0100
parents 0edcde404b02
children a7893294e9b2
line wrap: on
line diff
--- a/src/lib.rs	Fri Nov 27 12:56:37 2020 +0100
+++ b/src/lib.rs	Fri Nov 27 15:05:19 2020 +0100
@@ -11,8 +11,8 @@
 //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format.
 //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding.
 //!
-//! Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and
-//! `async_std` sockets if the appropriate features are activated.
+//! Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are
+//! activated.
 //! 
 //! ## Example
 //! 
@@ -41,9 +41,9 @@
 //!
 //! * `serde` - implements `serde::Deserialize` for `SocketAddr`
 //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr`
-//! * `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr`
-//! * `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr`
-//! * `async_std` - adds `bind_async_std` convenience method to `SocketAddr`
+//! * `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr`
+//! * `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr`
+//! * `async_std` - adds `bind_async_std` method to `SocketAddr`
 //!
 //! ## MSRV
 //!
@@ -54,11 +54,13 @@
 #![deny(missing_docs)]
 
 pub mod error;
+mod resolv_addr;
 
 use std::convert::{TryFrom, TryInto};
 use std::fmt;
 use std::ffi::{OsStr, OsString};
 use crate::error::*;
+use crate::resolv_addr::ResolvAddr;
 
 pub(crate) mod systemd_sockets {
     use std::fmt;
@@ -126,43 +128,89 @@
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
-        self._bind().map(|(socket, _)| socket)
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket),
+        }
     }
 
     /// Creates `tokio::net::TcpListener`
     ///
-    /// To be specific, it binds the socket and converts it to `tokio` 0.2 socket.
+    /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.2 socket.
     ///
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     #[cfg(feature = "tokio_0_2")]
-    pub fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
-        let (socket, addr) = self._bind()?;
-        socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+    pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::Systemd(socket_name) => {
+                let (socket, addr) = Self::get_systemd(socket_name)?;
+                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+            },
+        }
     }
 
     /// Creates `tokio::net::TcpListener`
     ///
-    /// To be specific, it binds the socket and converts it to `tokio` 0.3 socket.
+    /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.3 socket.
     ///
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     #[cfg(feature = "tokio_0_3")]
-    pub fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
-        let (socket, addr) = self._bind()?;
-        socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+    pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
+            },
+            SocketAddrInner::Systemd(socket_name) => {
+                let (socket, addr) = Self::get_systemd(socket_name)?;
+                socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
+            },
+        }
     }
 
     /// Creates `async_std::net::TcpListener`
     ///
-    /// To be specific, it binds the socket and converts it to `async_std` socket.
+    /// To be specific, it binds the socket or converts systemd socket to `async_std` socket.
     ///
     /// This method either `binds` the socket, if the address was provided or uses systemd socket
     /// if the socket name was provided.
     #[cfg(feature = "async-std")]
-    pub fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
-        let (socket, _) = self._bind()?;
-        Ok(socket.into())
+    pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
+        match self.0 {
+            SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await {
+                Ok(socket) => Ok(socket),
+                Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
+            },
+            SocketAddrInner::Systemd(socket_name) => {
+                let (socket, _) = Self::get_systemd(socket_name)?;
+                Ok(socket.into())
+            },
+        }
     }
 
     // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan
@@ -176,33 +224,28 @@
                 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()),
             }
         } else {
-            Ok(string.parse().map(SocketAddrInner::Ordinary).map(SocketAddr).map_err(ParseErrorInner::SocketAddr)?)
+            match string.parse() {
+                Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))),
+                Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))),
+            }
         }
     }
 
-    fn _bind(self) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
-        match self.0 {
-            SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
-                Ok(socket) => Ok((socket, SocketAddrInner::Ordinary(addr))),
-                Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
-            },
-            SocketAddrInner::Systemd(socket_name) => {
-                use libsystemd::activation::IsType;
-                use std::os::unix::io::{FromRawFd, IntoRawFd};
+    fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
+        use libsystemd::activation::IsType;
+        use std::os::unix::io::{FromRawFd, IntoRawFd};
 
-                let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?;
-                // Safety: The environment variable is unset, so that no other calls can get the
-                // descriptors. The descriptors are taken from the map, not cloned, so they can't
-                // be duplicated.
-                unsafe {
-                    // match instead of combinators to avoid cloning socket_name
-                    match socket {
-                        Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
-                        Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
-                        None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
-                    }
-                }
-            },
+        let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?;
+        // Safety: The environment variable is unset, so that no other calls can get the
+        // descriptors. The descriptors are taken from the map, not cloned, so they can't
+        // be duplicated.
+        unsafe {
+            // match instead of combinators to avoid cloning socket_name
+            match socket {
+                Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
+                Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
+                None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
+            }
         }
     }
 }
@@ -222,6 +265,7 @@
         match self {
             SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f),
             SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f),
+            SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f),
         }
     }
 }
@@ -230,6 +274,7 @@
 #[derive(Debug, PartialEq)]
 enum SocketAddrInner {
     Ordinary(std::net::SocketAddr),
+    WithHostname(resolv_addr::ResolvAddr),
     Systemd(String),
 }