comparison src/lib.rs @ 0:a65053246c29

Initial commit
author Martin Habovstiak <martin.habovstiak@gmail.com>
date Thu, 26 Nov 2020 22:53:35 +0100
parents
children cabc4aafdd85
comparison
equal deleted inserted replaced
-1:000000000000 0:a65053246c29
1 //! A convenience crate for optionally supporting systemd socket activation.
2 //!
3 //! ## About
4 //!
5 //! The goal of this crate is to make socket activation with systemd in your project trivial.
//! It provides a replacement for `std::net::SocketAddr` that parses the bind address from a string just like the one from `std`,
//! but on top of that also accepts the `systemd://socket_name` format, which tells it to use systemd activation with the given socket name.
8 //! Then it provides a method to bind the address which will return the socket from systemd if available.
9 //!
10 //! The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flag.
11 //! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format.
12 //! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding.
13 //!
14 //! Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and
15 //! `async_std` sockets if the appropriate features are activated.
16 //!
17 //! ## Example
18 //!
19 //! ```no_run
20 //! use systemd_socket::SocketAddr;
21 //! use std::convert::TryFrom;
22 //! use std::io::Write;
23 //!
24 //! let mut args = std::env::args_os();
25 //! let program_name = args.next().expect("unknown program name");
26 //! let socket_addr = args.next().expect("missing socket address");
27 //! let socket_addr = SocketAddr::try_from(socket_addr).expect("failed to parse socket address");
28 //! let socket = socket_addr.bind().expect("failed to bind socket");
29 //!
30 //! loop {
31 //! let _ = socket
32 //! .accept()
33 //! .expect("failed to accept connection")
34 //! .0
35 //! .write_all(b"Hello world!")
36 //! .map_err(|err| eprintln!("Failed to send {}", err));
37 //! }
38 //! ```
39 //!
40 //! ## Features
41 //!
42 //! * `serde` - implements `serde::Deserialize` for `SocketAddr`
43 //! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr`
44 //! * `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr`
45 //! * `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr`
//! * `async-std` - adds `bind_async_std` convenience method to `SocketAddr`
47
48 #![deny(missing_docs)]
49
50 pub mod error;
51
52 use std::convert::{TryFrom, TryInto};
53 use std::fmt;
54 use std::ffi::{OsStr, OsString};
55 use crate::error::*;
56
57 pub(crate) mod systemd_sockets {
58 use std::fmt;
59 use std::sync::Mutex;
60 use libsystemd::activation::FileDescriptor;
61 use libsystemd::errors::Error as LibSystemdError;
62 use libsystemd::errors::Result as LibSystemdResult;
63
64 #[derive(Debug)]
65 pub(crate) struct Error(&'static Mutex<LibSystemdError>);
66
67 impl fmt::Display for Error {
68 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
69 fmt::Display::fmt(&*self.0.lock().expect("mutex poisoned"), f)
70 }
71 }
72
73 // No source we can't keep the mutex locked
74 impl std::error::Error for Error {}
75
76 pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> {
77 match &*SYSTEMD_SOCKETS {
78 Ok(sockets) => Ok(sockets.take(name)),
79 Err(error) => Err(Error(error))
80 }
81 }
82
83 struct SystemdSockets(std::sync::Mutex<std::collections::HashMap<String, FileDescriptor>>);
84
85 impl SystemdSockets {
86 fn new() -> LibSystemdResult<Self> {
87 // MUST BE true FOR SAFETY!!!
88 let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ true)?.into_iter().map(|(fd, name)| (name, fd)).collect();
89 Ok(SystemdSockets(Mutex::new(map)))
90 }
91
92 fn take(&self, name: &str) -> Option<FileDescriptor> {
93 // MUST remove THE SOCKET FOR SAFETY!!!
94 self.0.lock().expect("poisoned mutex").remove(name)
95 }
96 }
97
98 lazy_static::lazy_static! {
99 // We don't panic in order to let the application handle the error later
100 static ref SYSTEMD_SOCKETS: Result<SystemdSockets, Mutex<LibSystemdError>> = SystemdSockets::new().map_err(Mutex::new);
101 }
102 }
103
104 /// Socket address that can be an ordinary address or a systemd socket
105 ///
106 /// This is the core type of this crate that abstracts possible addresses.
107 /// It can be (fallibly) converted from various types of strings or deserialized with `serde`.
108 /// After it's created, it can be bound as `TcpListener` from `std` or even `tokio` or `async_std`
109 /// if the appropriate feature is enabled.
110 ///
111 /// Optional dependencies on `parse_arg` and `serde` make it trivial to use with
112 /// [`configure_me`](https://crates.io/crates/configure_me).
113 #[derive(Debug)]
114 #[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))]
115 pub struct SocketAddr(SocketAddrInner);
116
117 impl SocketAddr {
118 /// Creates `std::net::TcpListener`
119 ///
120 /// This method either `binds` the socket, if the address was provided or uses systemd socket
121 /// if the socket name was provided.
122 pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
123 self._bind().map(|(socket, _)| socket)
124 }
125
126 /// Creates `tokio::net::TcpListener`
127 ///
128 /// To be specific, it binds the socket and converts it to `tokio` 0.2 socket.
129 ///
130 /// This method either `binds` the socket, if the address was provided or uses systemd socket
131 /// if the socket name was provided.
132 #[cfg(feature = "tokio_0_2")]
133 pub fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
134 let (socket, addr) = self._bind()?;
135 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
136 }
137
138 /// Creates `tokio::net::TcpListener`
139 ///
140 /// To be specific, it binds the socket and converts it to `tokio` 0.3 socket.
141 ///
142 /// This method either `binds` the socket, if the address was provided or uses systemd socket
143 /// if the socket name was provided.
144 #[cfg(feature = "tokio_0_3")]
145 pub fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
146 let (socket, addr) = self._bind()?;
147 socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
148 }
149
150 /// Creates `async_std::net::TcpListener`
151 ///
152 /// To be specific, it binds the socket and converts it to `async_std` socket.
153 ///
154 /// This method either `binds` the socket, if the address was provided or uses systemd socket
155 /// if the socket name was provided.
156 #[cfg(feature = "async-std")]
157 pub fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
158 let (socket, _) = self._bind()?;
159 Ok(socket.into())
160 }
161
162 // We can't impl<T: Deref<Target=str> + Into<String>> TryFrom<T> for SocketAddr because of orphan
163 // rules.
164 fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> {
165 if string.starts_with(SYSTEMD_PREFIX) {
166 match string[SYSTEMD_PREFIX.len()..].chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') {
167 None => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))),
168 Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()),
169 }
170 } else {
171 Ok(string.parse().map(SocketAddrInner::Ordinary).map(SocketAddr).map_err(ParseErrorInner::SocketAddr)?)
172 }
173 }
174
175 fn _bind(self) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
176 match self.0 {
177 SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
178 Ok(socket) => Ok((socket, SocketAddrInner::Ordinary(addr))),
179 Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
180 },
181 SocketAddrInner::Systemd(socket_name) => {
182 use libsystemd::activation::IsType;
183 use std::os::unix::io::{FromRawFd, IntoRawFd};
184
185 let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?;
186 // Safety: The environment variable is unset, so that no other calls can get the
187 // descriptors. The descriptors are taken from the map, not cloned, so they can't
188 // be duplicated.
189 unsafe {
190 // match instead of combinators to avoid cloning socket_name
191 match socket {
192 Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
193 Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
194 None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
195 }
196 }
197 },
198 }
199 }
200 }
201
202 /// Displays the address in format that can be parsed again.
203 ///
204 /// **Important: While I don't expect this impl to change, don't rely on it!**
205 /// It should be used mostly for debugging/logging.
206 impl fmt::Display for SocketAddr {
207 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
208 fmt::Display::fmt(&self.0, f)
209 }
210 }
211
// `PartialEq` is derived so tests can compare values; this type is private and it's not
// obvious the impl should ever become part of the public API.
#[derive(Debug, PartialEq)]
enum SocketAddrInner {
    /// A regular address that gets bound directly.
    Ordinary(std::net::SocketAddr),
    /// The complete `systemd://name` string, prefix included.
    Systemd(String),
}

impl fmt::Display for SocketAddrInner {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            SocketAddrInner::Ordinary(ref socket_addr) => fmt::Display::fmt(socket_addr, f),
            // `String`'s `Display` pads the text, so `pad` honours width/alignment identically.
            SocketAddrInner::Systemd(ref name) => f.pad(name),
        }
    }
}

const SYSTEMD_PREFIX: &str = "systemd://";
229
230 impl std::str::FromStr for SocketAddr {
231 type Err = ParseError;
232
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 SocketAddr::try_from_generic(s)
235 }
236 }
237
238 impl<'a> TryFrom<&'a str> for SocketAddr {
239 type Error = ParseError;
240
241 fn try_from(s: &'a str) -> Result<Self, Self::Error> {
242 SocketAddr::try_from_generic(s)
243 }
244 }
245
246 impl TryFrom<String> for SocketAddr {
247 type Error = ParseError;
248
249 fn try_from(s: String) -> Result<Self, Self::Error> {
250 SocketAddr::try_from_generic(s)
251 }
252 }
253
254 impl<'a> TryFrom<&'a OsStr> for SocketAddr {
255 type Error = ParseOsStrError;
256
257 fn try_from(s: &'a OsStr) -> Result<Self, Self::Error> {
258 s.to_str().ok_or(ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into)
259 }
260 }
261
262 impl TryFrom<OsString> for SocketAddr {
263 type Error = ParseOsStrError;
264
265 fn try_from(s: OsString) -> Result<Self, Self::Error> {
266 s.into_string().map_err(|_| ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into)
267 }
268 }
269
#[cfg(feature = "serde")]
impl<'a> TryFrom<serde_str_helpers::DeserBorrowStr<'a>> for SocketAddr {
    type Error = ParseError;

    /// Used by the `serde` `try_from` attribute; borrows the input when the deserializer allows.
    fn try_from(s: serde_str_helpers::DeserBorrowStr<'a>) -> Result<Self, Self::Error> {
        let text: std::borrow::Cow<'_, str> = s.into();
        SocketAddr::try_from_generic(text)
    }
}
278
#[cfg(feature = "parse_arg")]
impl parse_arg::ParseArg for SocketAddr {
    type Error = ParseOsStrError;

    /// Describes the accepted input: whatever `std::net::SocketAddr` accepts plus systemd names.
    fn describe_type<W: fmt::Write>(mut writer: W) -> fmt::Result {
        <std::net::SocketAddr as parse_arg::ParseArg>::describe_type(&mut writer)?;
        write!(writer, " or a systemd socket name prefixed with systemd://")
    }

    fn parse_arg(arg: &OsStr) -> Result<Self, Self::Error> {
        SocketAddr::try_from(arg)
    }

    fn parse_owned_arg(arg: OsString) -> Result<Self, Self::Error> {
        SocketAddr::try_from(arg)
    }
}
296
#[cfg(test)]
mod tests {
    use super::{SocketAddr, SocketAddrInner};

    /// Parses `input` and returns its inner representation, panicking if parsing fails.
    fn parse_inner(input: &str) -> SocketAddrInner {
        input.parse::<SocketAddr>().unwrap().0
    }

    #[test]
    fn parse_ordinary() {
        let expected = SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into());
        assert_eq!(parse_inner("127.0.0.1:42"), expected);
    }

    #[test]
    fn parse_systemd() {
        let expected = SocketAddrInner::Systemd("systemd://foo".to_owned());
        assert_eq!(parse_inner("systemd://foo"), expected);
    }
}