comparison src/lib.rs @ 17:dfb727367934

Allow specifying systemd name directly. When an application obtains a systemd socket name directly (for example by parsing its own configuration), it can use the function added in this change to construct a `SocketAddr` without first allocating an intermediate `systemd://`-prefixed string.
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Tue, 22 Dec 2020 14:15:49 +0100
parents bc76507dd878
children db1dc99252e2
comparison
equal deleted inserted replaced
16:bc76507dd878 17:dfb727367934
131 #[derive(Debug)] 131 #[derive(Debug)]
132 #[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))] 132 #[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))]
133 pub struct SocketAddr(SocketAddrInner); 133 pub struct SocketAddr(SocketAddrInner);
134 134
135 impl SocketAddr { 135 impl SocketAddr {
136 /// Creates SocketAddr from systemd name directly, without requiring `systemd://` prefix.
137 ///
138 /// Always fails with systemd unsupported error if systemd is not supported.
139 pub fn from_systemd_name<T: Into<String>>(name: T) -> Result<Self, ParseError> {
140 Self::inner_from_systemd_name(name.into(), false)
141 }
142
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
/// Validates a systemd socket name and wraps it in the matching `SocketAddrInner` variant.
///
/// `prefixed` says whether `name` still starts with `systemd://`. If so, only the part
/// after the prefix is validated, and the stored string keeps the prefix.
fn inner_from_systemd_name(name: String, prefixed: bool) -> Result<Self, ParseError> {
    // Skip the `systemd://` prefix (when present) so only the bare name is checked.
    let start = if prefixed { SYSTEMD_PREFIX.len() } else { 0 };
    let bare_name = &name[start..];
    let len = bare_name.len();

    // `pos` is a character index into the bare name. Everything before the first
    // offending character is ASCII, so it also equals the byte offset.
    let first_invalid = bare_name
        .chars()
        .enumerate()
        .find(|&(_, ch)| !ch.is_ascii() || ch < ' ' || ch == ':');

    if let Some((pos, c)) = first_invalid {
        return Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into());
    }

    if len > 255 {
        return Err(ParseErrorInner::LongSocketName { string: name, len }.into());
    }

    let inner = if prefixed {
        SocketAddrInner::Systemd(name)
    } else {
        SocketAddrInner::SystemdNoPrefix(name)
    };
    Ok(SocketAddr(inner))
}
159
160
161 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
162 fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> {
163 Err(ParseError(ParseErrorInner::SystemdUnsupported(name)))
164 }
165
136 /// Creates `std::net::TcpListener` 166 /// Creates `std::net::TcpListener`
137 /// 167 ///
138 /// This method either `binds` the socket, if the address was provided or uses systemd socket 168 /// This method either `binds` the socket, if the address was provided or uses systemd socket
139 /// if the socket name was provided. 169 /// if the socket name was provided.
140 pub fn bind(self) -> Result<std::net::TcpListener, BindError> { 170 pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
145 }, 175 },
146 SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) { 176 SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) {
147 Ok(socket) => Ok(socket), 177 Ok(socket) => Ok(socket),
148 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), 178 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
149 }, 179 },
150 SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket), 180 SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket),
181 SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket),
151 } 182 }
152 } 183 }
153 184
154 /// Creates `tokio::net::TcpListener` 185 /// Creates `tokio::net::TcpListener`
155 /// 186 ///
167 SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { 198 SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await {
168 Ok(socket) => Ok(socket), 199 Ok(socket) => Ok(socket),
169 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), 200 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
170 }, 201 },
171 SocketAddrInner::Systemd(socket_name) => { 202 SocketAddrInner::Systemd(socket_name) => {
172 let (socket, addr) = Self::get_systemd(socket_name)?; 203 let (socket, addr) = Self::get_systemd(socket_name, true)?;
204 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
205 },
206 SocketAddrInner::SystemdNoPrefix(socket_name) => {
207 let (socket, addr) = Self::get_systemd(socket_name, false)?;
173 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) 208 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
174 }, 209 },
175 } 210 }
176 } 211 }
177 212
191 SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { 226 SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await {
192 Ok(socket) => Ok(socket), 227 Ok(socket) => Ok(socket),
193 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), 228 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
194 }, 229 },
195 SocketAddrInner::Systemd(socket_name) => { 230 SocketAddrInner::Systemd(socket_name) => {
196 let (socket, addr) = Self::get_systemd(socket_name)?; 231 let (socket, addr) = Self::get_systemd(socket_name, true)?;
232 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
233 },
234 SocketAddrInner::SystemdNoPrefix(socket_name) => {
235 let (socket, addr) = Self::get_systemd(socket_name, false)?;
197 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) 236 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
198 }, 237 },
199 } 238 }
200 } 239 }
201 240
215 SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await { 254 SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await {
216 Ok(socket) => Ok(socket), 255 Ok(socket) => Ok(socket),
217 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), 256 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
218 }, 257 },
219 SocketAddrInner::Systemd(socket_name) => { 258 SocketAddrInner::Systemd(socket_name) => {
220 let (socket, _) = Self::get_systemd(socket_name)?; 259 let (socket, _) = Self::get_systemd(socket_name, true)?;
260 Ok(socket.into())
261 },
262 SocketAddrInner::SystemdNoPrefix(socket_name) => {
263 let (socket, _) = Self::get_systemd(socket_name, false)?;
221 Ok(socket.into()) 264 Ok(socket.into())
222 }, 265 },
223 } 266 }
224 } 267 }
225 268
226 // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan 269 // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan
227 // rules. 270 // rules.
228 fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> { 271 fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> {
229 if string.starts_with(SYSTEMD_PREFIX) { 272 if string.starts_with(SYSTEMD_PREFIX) {
230 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] 273 Self::inner_from_systemd_name(string.into(), true)
231 {
232 let name_len = string.len() - SYSTEMD_PREFIX.len();
233 match string[SYSTEMD_PREFIX.len()..].chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') {
234 None if name_len <= 255 => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))),
235 None => Err(ParseErrorInner::LongSocketName { string: string.into(), len: name_len }.into()),
236 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()),
237 }
238 }
239 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
240 {
241 Err(ParseErrorInner::SystemdUnsupported(string.into()).into())
242 }
243 } else { 274 } else {
244 match string.parse() { 275 match string.parse() {
245 Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), 276 Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))),
246 Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))), 277 Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))),
247 } 278 }
248 } 279 }
249 } 280 }
250 281
251 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] 282 #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
252 fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { 283 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
253 use libsystemd::activation::IsType; 284 use libsystemd::activation::IsType;
254 use std::os::unix::io::{FromRawFd, IntoRawFd}; 285 use std::os::unix::io::{FromRawFd, IntoRawFd};
255 286
256 let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; 287 let real_systemd_name = if prefixed {
288 &socket_name[SYSTEMD_PREFIX.len()..]
289 } else {
290 &socket_name
291 };
292
293 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?;
257 // Safety: The environment variable is unset, so that no other calls can get the 294 // Safety: The environment variable is unset, so that no other calls can get the
258 // descriptors. The descriptors are taken from the map, not cloned, so they can't 295 // descriptors. The descriptors are taken from the map, not cloned, so they can't
259 // be duplicated. 296 // be duplicated.
260 unsafe { 297 unsafe {
261 // match instead of combinators to avoid cloning socket_name 298 // match instead of combinators to avoid cloning socket_name
268 } 305 }
269 306
270 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it 307 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it
271 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. 308 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute.
272 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] 309 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
273 fn get_systemd(socket_name: Never) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { 310 fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
274 match socket_name {} 311 match socket_name {}
275 } 312 }
276 } 313 }
277 314
278 /// Displays the address in format that can be parsed again. 315 /// Displays the address in format that can be parsed again.
288 impl fmt::Display for SocketAddrInner { 325 impl fmt::Display for SocketAddrInner {
289 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 326 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
290 match self { 327 match self {
291 SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), 328 SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f),
292 SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), 329 SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f),
330 SocketAddrInner::SystemdNoPrefix(addr) => write!(f, "{}{}", SYSTEMD_PREFIX, addr),
293 SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f), 331 SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f),
294 } 332 }
295 } 333 }
296 } 334 }
297 335
303 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] 341 #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
304 Systemd(String), 342 Systemd(String),
305 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] 343 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
306 #[allow(dead_code)] 344 #[allow(dead_code)]
307 Systemd(Never), 345 Systemd(Never),
346 #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
347 #[allow(dead_code)]
348 SystemdNoPrefix(String),
349 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
350 #[allow(dead_code)]
351 SystemdNoPrefix(Never),
308 } 352 }
309 353
310 const SYSTEMD_PREFIX: &str = "systemd://"; 354 const SYSTEMD_PREFIX: &str = "systemd://";
311 355
312 impl<I: Into<std::net::IpAddr>> From<(I, u16)> for SocketAddr { 356 impl<I: Into<std::net::IpAddr>> From<(I, u16)> for SocketAddr {
437 #[test] 481 #[test]
438 #[should_panic] 482 #[should_panic]
439 fn parse_systemd_fail_too_long() { 483 fn parse_systemd_fail_too_long() {
440 "systemd://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".parse::<SocketAddr>().unwrap(); 484 "systemd://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".parse::<SocketAddr>().unwrap();
441 } 485 }
442 } 486
#[test]
// Without systemd support every systemd name is rejected, so `unwrap` is expected to panic.
#[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
fn no_prefix_parse_systemd() {
    let result = SocketAddr::from_systemd_name("foo");
    result.unwrap();
}
492
#[test]
#[should_panic]
fn no_prefix_parse_systemd_fail_non_ascii() {
    // Non-ASCII characters are not allowed in systemd socket names.
    let name = "fooĆ”";
    let result = SocketAddr::from_systemd_name(name);
    result.unwrap();
}
498
#[test]
#[should_panic]
fn no_prefix_parse_systemd_fail_too_long() {
    // 256 bytes is the shortest length rejected by the 255-byte limit on systemd names.
    SocketAddr::from_systemd_name("x".repeat(256)).unwrap();
}

#[test]
// Exactly 255 bytes is still valid; only builds without systemd support reject it.
#[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
fn no_prefix_parse_systemd_max_len() {
    SocketAddr::from_systemd_name("x".repeat(255)).unwrap();
}
504 }