Mercurial > crates > systemd-socket
diff src/lib.rs @ 4:66c0e10c89fc
Support resolving hostnames
Until now the crate supported only IP addresses and systemd sockets.
This was troublesome because it prevented the popular `localhost:1234`
format. This commit changes the behavior so that if parsing of
`std::net::SocketAddr` fails it attempts to parse it as `hostname:port`.
`bind_*()` methods were also modified to be async because of this.
author | Martin Habovstiak <martin.habovstiak@gmail.com> |
---|---|
date | Fri, 27 Nov 2020 15:05:19 +0100 |
parents | 0edcde404b02 |
children | a7893294e9b2 |
line wrap: on
line diff
--- a/src/lib.rs Fri Nov 27 12:56:37 2020 +0100 +++ b/src/lib.rs Fri Nov 27 15:05:19 2020 +0100 @@ -11,8 +11,8 @@ //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format. //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding. //! -//! Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and -//! `async_std` sockets if the appropriate features are activated. +//! Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are +//! activated. //! //! ## Example //! @@ -41,9 +41,9 @@ //! //! * `serde` - implements `serde::Deserialize` for `SocketAddr` //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr` -//! * `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr` -//! * `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr` -//! * `async_std` - adds `bind_async_std` convenience method to `SocketAddr` +//! * `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr` +//! * `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr` +//! * `async_std` - adds `bind_async_std` method to `SocketAddr` //! //! ## MSRV //! @@ -54,11 +54,13 @@ #![deny(missing_docs)] pub mod error; +mod resolv_addr; use std::convert::{TryFrom, TryInto}; use std::fmt; use std::ffi::{OsStr, OsString}; use crate::error::*; +use crate::resolv_addr::ResolvAddr; pub(crate) mod systemd_sockets { use std::fmt; @@ -126,43 +128,89 @@ /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. pub fn bind(self) -> Result<std::net::TcpListener, BindError> { - self._bind().map(|(socket, _)| socket) + match self.0 { + SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), + }, + SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), + }, + SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket), + } } /// Creates `tokio::net::TcpListener` /// - /// To be specific, it binds the socket and converts it to `tokio` 0.2 socket. + /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.2 socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "tokio_0_2")] - pub fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> { - let (socket, addr) = self._bind()?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> { + match self.0 { + SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), + }, + SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), + }, + SocketAddrInner::Systemd(socket_name) => { + let (socket, addr) = Self::get_systemd(socket_name)?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + }, + } } /// Creates `tokio::net::TcpListener` /// - /// To be specific, it binds the socket and converts it to `tokio` 0.3 socket. + /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.3 socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "tokio_0_3")] - pub fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> { - let (socket, addr) = self._bind()?; - socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> { + match self.0 { + SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())), + }, + SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), + }, + SocketAddrInner::Systemd(socket_name) => { + let (socket, addr) = Self::get_systemd(socket_name)?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + }, + } } /// Creates `async_std::net::TcpListener` /// - /// To be specific, it binds the socket and converts it to `async_std` socket. + /// To be specific, it binds the socket or converts systemd socket to `async_std` socket. /// /// This method either `binds` the socket, if the address was provided or uses systemd socket /// if the socket name was provided. #[cfg(feature = "async-std")] - pub fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> { - let (socket, _) = self._bind()?; - Ok(socket.into()) + pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> { + match self.0 { + SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), + }, + SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await { + Ok(socket) => Ok(socket), + Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), + }, + SocketAddrInner::Systemd(socket_name) => { + let (socket, _) = Self::get_systemd(socket_name)?; + Ok(socket.into()) + }, + } } // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan @@ -176,33 +224,28 @@ Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()), } } else { - Ok(string.parse().map(SocketAddrInner::Ordinary).map(SocketAddr).map_err(ParseErrorInner::SocketAddr)?) + match string.parse() { + Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), + Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))), + } } } - fn _bind(self) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { - match self.0 { - SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { - Ok(socket) => Ok((socket, SocketAddrInner::Ordinary(addr))), - Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), - }, - SocketAddrInner::Systemd(socket_name) => { - use libsystemd::activation::IsType; - use std::os::unix::io::{FromRawFd, IntoRawFd}; + fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { + use libsystemd::activation::IsType; + use std::os::unix::io::{FromRawFd, IntoRawFd}; - let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; - // Safety: The environment variable is unset, so that no other calls can get the - // descriptors. The descriptors are taken from the map, not cloned, so they can't - // be duplicated. - unsafe { - // match instead of combinators to avoid cloning socket_name - match socket { - Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), - Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), - None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) - } - } - }, + let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; + // Safety: The environment variable is unset, so that no other calls can get the + // descriptors. The descriptors are taken from the map, not cloned, so they can't + // be duplicated. + unsafe { + // match instead of combinators to avoid cloning socket_name + match socket { + Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), + Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), + None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) + } } } } @@ -222,6 +265,7 @@ match self { SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), + SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f), } } } @@ -230,6 +274,7 @@ #[derive(Debug, PartialEq)] enum SocketAddrInner { Ordinary(std::net::SocketAddr), + WithHostname(resolv_addr::ResolvAddr), Systemd(String), }