Mercurial > crates > systemd-socket
comparison src/lib.rs @ 17:dfb727367934
Allow specifying systemd name directly
When an application obtains a systemd socket name directly by parsing its
configuration, it can use the function added in this change to construct a
`SocketAddr` without having to allocate an intermediate string.
author | Martin Habovstiak <martin.habovstiak@gmail.com> |
---|---|
date | Tue, 22 Dec 2020 14:15:49 +0100 |
parents | bc76507dd878 |
children | db1dc99252e2 |
comparison
equal
deleted
inserted
replaced
16:bc76507dd878 | 17:dfb727367934 |
---|---|
131 #[derive(Debug)] | 131 #[derive(Debug)] |
132 #[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))] | 132 #[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))] |
133 pub struct SocketAddr(SocketAddrInner); | 133 pub struct SocketAddr(SocketAddrInner); |
134 | 134 |
135 impl SocketAddr { | 135 impl SocketAddr { |
136 /// Creates SocketAddr from systemd name directly, without requiring `systemd://` prefix. | |
137 /// | |
138 /// Always fails with systemd unsupported error if systemd is not supported. | |
139 pub fn from_systemd_name<T: Into<String>>(name: T) -> Result<Self, ParseError> { | |
140 Self::inner_from_systemd_name(name.into(), false) | |
141 } | |
142 | |
143 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | |
144 fn inner_from_systemd_name(name: String, prefixed: bool) -> Result<Self, ParseError> { | |
145 let real_systemd_name = if prefixed { | |
146 &name[SYSTEMD_PREFIX.len()..] | |
147 } else { | |
148 &name | |
149 }; | |
150 | |
151 let name_len = real_systemd_name.len(); | |
152 match real_systemd_name.chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') { | |
153 None if name_len <= 255 && prefixed => Ok(SocketAddr(SocketAddrInner::Systemd(name))), | |
154 None if name_len <= 255 && !prefixed => Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))), | |
155 None => Err(ParseErrorInner::LongSocketName { string: name, len: name_len }.into()), | |
156 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into()), | |
157 } | |
158 } | |
159 | |
160 | |
161 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] | |
162 fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> { | |
163 Err(ParseError(ParseErrorInner::SystemdUnsupported(name))) | |
164 } | |
165 | |
136 /// Creates `std::net::TcpListener` | 166 /// Creates `std::net::TcpListener` |
137 /// | 167 /// |
138 /// This method either `binds` the socket, if the address was provided or uses systemd socket | 168 /// This method either `binds` the socket, if the address was provided or uses systemd socket |
139 /// if the socket name was provided. | 169 /// if the socket name was provided. |
140 pub fn bind(self) -> Result<std::net::TcpListener, BindError> { | 170 pub fn bind(self) -> Result<std::net::TcpListener, BindError> { |
145 }, | 175 }, |
146 SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) { | 176 SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) { |
147 Ok(socket) => Ok(socket), | 177 Ok(socket) => Ok(socket), |
148 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), | 178 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), |
149 }, | 179 }, |
150 SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name).map(|(socket, _)| socket), | 180 SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket), |
181 SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket), | |
151 } | 182 } |
152 } | 183 } |
153 | 184 |
154 /// Creates `tokio::net::TcpListener` | 185 /// Creates `tokio::net::TcpListener` |
155 /// | 186 /// |
167 SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { | 198 SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await { |
168 Ok(socket) => Ok(socket), | 199 Ok(socket) => Ok(socket), |
169 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), | 200 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), |
170 }, | 201 }, |
171 SocketAddrInner::Systemd(socket_name) => { | 202 SocketAddrInner::Systemd(socket_name) => { |
172 let (socket, addr) = Self::get_systemd(socket_name)?; | 203 let (socket, addr) = Self::get_systemd(socket_name, true)?; |
204 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) | |
205 }, | |
206 SocketAddrInner::SystemdNoPrefix(socket_name) => { | |
207 let (socket, addr) = Self::get_systemd(socket_name, false)?; | |
173 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) | 208 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) |
174 }, | 209 }, |
175 } | 210 } |
176 } | 211 } |
177 | 212 |
191 SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { | 226 SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await { |
192 Ok(socket) => Ok(socket), | 227 Ok(socket) => Ok(socket), |
193 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), | 228 Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())), |
194 }, | 229 }, |
195 SocketAddrInner::Systemd(socket_name) => { | 230 SocketAddrInner::Systemd(socket_name) => { |
196 let (socket, addr) = Self::get_systemd(socket_name)?; | 231 let (socket, addr) = Self::get_systemd(socket_name, true)?; |
232 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) | |
233 }, | |
234 SocketAddrInner::SystemdNoPrefix(socket_name) => { | |
235 let (socket, addr) = Self::get_systemd(socket_name, false)?; | |
197 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) | 236 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) |
198 }, | 237 }, |
199 } | 238 } |
200 } | 239 } |
201 | 240 |
215 SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await { | 254 SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await { |
216 Ok(socket) => Ok(socket), | 255 Ok(socket) => Ok(socket), |
217 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), | 256 Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()), |
218 }, | 257 }, |
219 SocketAddrInner::Systemd(socket_name) => { | 258 SocketAddrInner::Systemd(socket_name) => { |
220 let (socket, _) = Self::get_systemd(socket_name)?; | 259 let (socket, _) = Self::get_systemd(socket_name, true)?; |
260 Ok(socket.into()) | |
261 }, | |
262 SocketAddrInner::SystemdNoPrefix(socket_name) => { | |
263 let (socket, _) = Self::get_systemd(socket_name, false)?; | |
221 Ok(socket.into()) | 264 Ok(socket.into()) |
222 }, | 265 }, |
223 } | 266 } |
224 } | 267 } |
225 | 268 |
226 // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan | 269 // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan |
227 // rules. | 270 // rules. |
228 fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> { | 271 fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> { |
229 if string.starts_with(SYSTEMD_PREFIX) { | 272 if string.starts_with(SYSTEMD_PREFIX) { |
230 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 273 Self::inner_from_systemd_name(string.into(), true) |
231 { | |
232 let name_len = string.len() - SYSTEMD_PREFIX.len(); | |
233 match string[SYSTEMD_PREFIX.len()..].chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') { | |
234 None if name_len <= 255 => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))), | |
235 None => Err(ParseErrorInner::LongSocketName { string: string.into(), len: name_len }.into()), | |
236 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()), | |
237 } | |
238 } | |
239 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] | |
240 { | |
241 Err(ParseErrorInner::SystemdUnsupported(string.into()).into()) | |
242 } | |
243 } else { | 274 } else { |
244 match string.parse() { | 275 match string.parse() { |
245 Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), | 276 Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))), |
246 Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))), | 277 Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))), |
247 } | 278 } |
248 } | 279 } |
249 } | 280 } |
250 | 281 |
251 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 282 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
252 fn get_systemd(socket_name: String) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { | 283 fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { |
253 use libsystemd::activation::IsType; | 284 use libsystemd::activation::IsType; |
254 use std::os::unix::io::{FromRawFd, IntoRawFd}; | 285 use std::os::unix::io::{FromRawFd, IntoRawFd}; |
255 | 286 |
256 let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; | 287 let real_systemd_name = if prefixed { |
288 &socket_name[SYSTEMD_PREFIX.len()..] | |
289 } else { | |
290 &socket_name | |
291 }; | |
292 | |
293 let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?; | |
257 // Safety: The environment variable is unset, so that no other calls can get the | 294 // Safety: The environment variable is unset, so that no other calls can get the |
258 // descriptors. The descriptors are taken from the map, not cloned, so they can't | 295 // descriptors. The descriptors are taken from the map, not cloned, so they can't |
259 // be duplicated. | 296 // be duplicated. |
260 unsafe { | 297 unsafe { |
261 // match instead of combinators to avoid cloning socket_name | 298 // match instead of combinators to avoid cloning socket_name |
268 } | 305 } |
269 | 306 |
270 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it | 307 // This approach makes the rest of the code much simpler as it doesn't require sprinkling it |
271 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. | 308 // with #[cfg(all(target_os = "linux", feature = "enable_systemd"))] yet still statically guarantees it won't execute. |
272 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] | 309 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] |
273 fn get_systemd(socket_name: Never) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { | 310 fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { |
274 match socket_name {} | 311 match socket_name {} |
275 } | 312 } |
276 } | 313 } |
277 | 314 |
278 /// Displays the address in format that can be parsed again. | 315 /// Displays the address in format that can be parsed again. |
288 impl fmt::Display for SocketAddrInner { | 325 impl fmt::Display for SocketAddrInner { |
289 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | 326 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
290 match self { | 327 match self { |
291 SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), | 328 SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), |
292 SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), | 329 SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), |
330 SocketAddrInner::SystemdNoPrefix(addr) => write!(f, "{}{}", SYSTEMD_PREFIX, addr), | |
293 SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f), | 331 SocketAddrInner::WithHostname(addr) => fmt::Display::fmt(addr, f), |
294 } | 332 } |
295 } | 333 } |
296 } | 334 } |
297 | 335 |
303 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | 341 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] |
304 Systemd(String), | 342 Systemd(String), |
305 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] | 343 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] |
306 #[allow(dead_code)] | 344 #[allow(dead_code)] |
307 Systemd(Never), | 345 Systemd(Never), |
346 #[cfg(all(target_os = "linux", feature = "enable_systemd"))] | |
347 #[allow(dead_code)] | |
348 SystemdNoPrefix(String), | |
349 #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))] | |
350 #[allow(dead_code)] | |
351 SystemdNoPrefix(Never), | |
308 } | 352 } |
309 | 353 |
310 const SYSTEMD_PREFIX: &str = "systemd://"; | 354 const SYSTEMD_PREFIX: &str = "systemd://"; |
311 | 355 |
312 impl<I: Into<std::net::IpAddr>> From<(I, u16)> for SocketAddr { | 356 impl<I: Into<std::net::IpAddr>> From<(I, u16)> for SocketAddr { |
437 #[test] | 481 #[test] |
438 #[should_panic] | 482 #[should_panic] |
439 fn parse_systemd_fail_too_long() { | 483 fn parse_systemd_fail_too_long() { |
440 "systemd://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".parse::<SocketAddr>().unwrap(); | 484 "systemd://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx".parse::<SocketAddr>().unwrap(); |
441 } | 485 } |
442 } | 486 |
487 #[test] | |
488 #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)] | |
489 fn no_prefix_parse_systemd() { | |
490 SocketAddr::from_systemd_name("foo").unwrap(); | |
491 } | |
492 | |
493 #[test] | |
494 #[should_panic] | |
495 fn no_prefix_parse_systemd_fail_non_ascii() { | |
496 SocketAddr::from_systemd_name("fooĆ”").unwrap(); | |
497 } | |
498 | |
499 #[test] | |
500 #[should_panic] | |
501 fn no_prefix_parse_systemd_fail_too_long() { | |
502 SocketAddr::from_systemd_name("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx").unwrap(); | |
503 } | |
504 } |