comparison src/lib.rs @ 4:66c0e10c89fc

Support resolving hostnames. Until now the crate supported only IP addresses and systemd sockets. This was troublesome because it prevented use of the popular `localhost:1234` format. This commit changes the behavior so that if parsing as `std::net::SocketAddr` fails, it attempts to parse the string as `hostname:port`. Because of this, the `bind_*()` methods were also made async.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Fri, 27 Nov 2020 15:05:19 +0100
parents 0edcde404b02
children a7893294e9b2
comparison
equal deleted inserted replaced
3:0edcde404b02 4:66c0e10c89fc
9 //! 9 //!
10 //! The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flag. 10 //! The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flag.
11 //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format. 11 //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format.
12 //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding. 12 //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding.
13 //! 13 //!
14 //! Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and 14 //! Further, the crate also provides methods for binding `tokio` 0.2, 0.3, and `async_std` sockets if the appropriate features are
15 //! `async_std` sockets if the appropriate features are activated. 15 //! activated.
16 //! 16 //!
17 //! ## Example 17 //! ## Example
18 //! 18 //!
19 //! ```no_run 19 //! ```no_run
20 //! use systemd_socket::SocketAddr; 20 //! use systemd_socket::SocketAddr;
39 //! 39 //!
40 //! ## Features 40 //! ## Features
41 //! 41 //!
42 //! * `serde` - implements `serde::Deserialize` for `SocketAddr` 42 //! * `serde` - implements `serde::Deserialize` for `SocketAddr`
43 //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr` 43 //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr`
44 //! * `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr` 44 //! * `tokio_0_2` - adds `bind_tokio_0_2` method to `SocketAddr`
45 //! * `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr` 45 //! * `tokio_0_3` - adds `bind_tokio_0_3` method to `SocketAddr`
46 //! * `async_std` - adds `bind_async_std` convenience method to `SocketAddr` 46 //! * `async_std` - adds `bind_async_std` method to `SocketAddr`
47 //! 47 //!
48 //! ## MSRV 48 //! ## MSRV
49 //! 49 //!
50 //! This crate must always compile with the latest Rust available in the latest Debian stable. 50 //! This crate must always compile with the latest Rust available in the latest Debian stable.
51 //! That is currently Rust 1.41.1. (Debian 10 - Buster) 51 //! That is currently Rust 1.41.1. (Debian 10 - Buster)
52 52
53 53
54 #![deny(missing_docs)] 54 #![deny(missing_docs)]
55 55
56 pub mod error; 56 pub mod error;
57 mod resolv_addr;
57 58
58 use std::convert::{TryFrom, TryInto}; 59 use std::convert::{TryFrom, TryInto};
59 use std::fmt; 60 use std::fmt;
60 use std::ffi::{OsStr, OsString}; 61 use std::ffi::{OsStr, OsString};
61 use crate::error::*; 62 use crate::error::*;
63 use crate::resolv_addr::ResolvAddr;
62 64
63 pub(crate) mod systemd_sockets { 65 pub(crate) mod systemd_sockets {
64 use std::fmt; 66 use std::fmt;
65 use std::sync::Mutex; 67 use std::sync::Mutex;
66 use libsystemd::activation::FileDescriptor; 68 use libsystemd::activation::FileDescriptor;
124 /// Creates `std::net::TcpListener` 126 /// Creates `std::net::TcpListener`
125 /// 127 ///
126 /// This method either `binds` the socket, if the address was provided or uses systemd socket 128 /// This method either `binds` the socket, if the address was provided or uses systemd socket
127 /// if the socket name was provided. 129 /// if the socket name was provided.
128 pub fn bind(self) -> Result<std::net::TcpListener, BindError> { 130 pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
129 self._bind().map(|(socket, _)| socket) 131 match self.0 {
132 SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
133 Ok(socket) => Ok(socket),
134 Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
135 },
136 SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) {
137 Ok(socket) => Ok(socket),
138 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
139 },
140 SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket),
141 }
130 } 142 }
131 143
132 /// Creates `tokio::net::TcpListener` 144 /// Creates `tokio::net::TcpListener`
133 /// 145 ///
134 /// To be specific, it binds the socket and converts it to `tokio` 0.2 socket. 146 /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.2 socket.
135 /// 147 ///
136 /// This method either `binds` the socket, if the address was provided or uses systemd socket 148 /// This method either `binds` the socket, if the address was provided or uses systemd socket
137 /// if the socket name was provided. 149 /// if the socket name was provided.
138 #[cfg(feature = "tokio_0_2")] 150 #[cfg(feature = "tokio_0_2")]
139 pub fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> { 151 pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
140 let (socket, addr) = self._bind()?; 152 match self.0 {
141 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) 153 SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await {
154 Ok(socket) => Ok(socket),
155 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
156 },
157 SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await {
158 Ok(socket) => Ok(socket),
159 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
160 },
161 SocketAddrInner::Systemd(socket_name) => {
162 let (socket, addr) = Self::get_systemd(socket_name)?;
163 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
164 },
165 }
142 } 166 }
143 167
144 /// Creates `tokio::net::TcpListener` 168 /// Creates `tokio::net::TcpListener`
145 /// 169 ///
146 /// To be specific, it binds the socket and converts it to `tokio` 0.3 socket. 170 /// To be specific, it binds the socket or converts systemd socket to `tokio` 0.3 socket.
147 /// 171 ///
148 /// This method either `binds` the socket, if the address was provided or uses systemd socket 172 /// This method either `binds` the socket, if the address was provided or uses systemd socket
149 /// if the socket name was provided. 173 /// if the socket name was provided.
150 #[cfg(feature = "tokio_0_3")] 174 #[cfg(feature = "tokio_0_3")]
151 pub fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> { 175 pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
152 let (socket, addr) = self._bind()?; 176 match self.0 {
153 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) 177 SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await {
178 Ok(socket) => Ok(socket),
179 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
180 },
181 SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await {
182 Ok(socket) => Ok(socket),
183 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
184 },
185 SocketAddrInner::Systemd(socket_name) => {
186 let (socket, addr) = Self::get_systemd(socket_name)?;
187 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
188 },
189 }
154 } 190 }
155 191
156 /// Creates `async_std::net::TcpListener` 192 /// Creates `async_std::net::TcpListener`
157 /// 193 ///
158 /// To be specific, it binds the socket and converts it to `async_std` socket. 194 /// To be specific, it binds the socket or converts systemd socket to `async_std` socket.
159 /// 195 ///
160 /// This method either `binds` the socket, if the address was provided or uses systemd socket 196 /// This method either `binds` the socket, if the address was provided or uses systemd socket
161 /// if the socket name was provided. 197 /// if the socket name was provided.
162 #[cfg(feature = "async-std")] 198 #[cfg(feature = "async-std")]
163 pub fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> { 199 pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
164 let (socket, _) = self._bind()?; 200 match self.0 {
165 Ok(socket.into()) 201 SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await {
202 Ok(socket) => Ok(socket),
203 Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
204 },
205 SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await {
206 Ok(socket) => Ok(socket),
207 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
208 },
209 SocketAddrInner::Systemd(socket_name) => {
210 let (socket, _) = Self::get_systemd(socket_name)?;
211 Ok(socket.into())
212 },
213 }
166 } 214 }
167 215
168 // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan 216 // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan
169 // rules. 217 // rules.
170 fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> { 218 fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> {
174 None if name_len <= 255 => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))), 222 None if name_len <= 255 => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))),
175 None => Err(ParseErrorInner::LongSocketName { string: string.into(), len: name_len }.into()), 223 None => Err(ParseErrorInner::LongSocketName { string: string.into(), len: name_len }.into()),
176 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()), 224 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()),
177 } 225 }
178 } else { 226 } else {
179 Ok(string.parse().map(SocketAddrInner::Ordinary).map(SocketAddr).map_err(ParseErrorInner::SocketAddr)?) 227 match string.parse() {
180 } 228 Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))),
181 } 229 Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))),
182 230 }
183 fn _bind(self) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { 231 }
184 match self.0 { 232 }
185 SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { 233
186 Ok(socket) => Ok((socket, SocketAddrInner::Ordinary(addr))), 234 fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
187 Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), 235 use libsystemd::activation::IsType;
188 }, 236 use std::os::unix::io::{FromRawFd, IntoRawFd};
189 SocketAddrInner::Systemd(socket_name) => { 237
190 use libsystemd::activation::IsType; 238 let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?;
191 use std::os::unix::io::{FromRawFd, IntoRawFd}; 239 // Safety: The environment variable is unset, so that no other calls can get the
192 240 // descriptors. The descriptors are taken from the map, not cloned, so they can't
193 let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; 241 // be duplicated.
194 // Safety: The environment variable is unset, so that no other calls can get the 242 unsafe {
195 // descriptors. The descriptors are taken from the map, not cloned, so they can't 243 // match instead of combinators to avoid cloning socket_name
196 // be duplicated. 244 match socket {
197 unsafe { 245 Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
198 // match instead of combinators to avoid cloning socket_name 246 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
199 match socket { 247 None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
200 Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), 248 }
201 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
202 None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
203 }
204 }
205 },
206 } 249 }
207 } 250 }
208 } 251 }
209 252
210 /// Displays the address in format that can be parsed again. 253 /// Displays the address in format that can be parsed again.
220 impl fmt::Display for SocketAddrInner { 263 impl fmt::Display for SocketAddrInner {
221 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 264 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
222 match self { 265 match self {
223 SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), 266 SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f),
224 SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), 267 SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f),
268 SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f),
225 } 269 }
226 } 270 }
227 } 271 }
228 272
229 // PartialEq for testing, I'm not convinced it should be exposed 273 // PartialEq for testing, I'm not convinced it should be exposed
230 #[derive(Debug, PartialEq)] 274 #[derive(Debug, PartialEq)]
231 enum SocketAddrInner { 275 enum SocketAddrInner {
232 Ordinary(std::net::SocketAddr), 276 Ordinary(std::net::SocketAddr),
277 WithHostname(resolv_addr::ResolvAddr),
233 Systemd(String), 278 Systemd(String),
234 } 279 }
235 280
236 const SYSTEMD_PREFIX: &str = "systemd://"; 281 const SYSTEMD_PREFIX: &str = "systemd://";
237 282