# HG changeset patch # User Martin Habovstiak # Date 1606427615 -3600 # Node ID a65053246c29aa32c305001516249bfd6da8af55 Initial commit diff -r 000000000000 -r a65053246c29 .gitignore --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/.gitignore Thu Nov 26 22:53:35 2020 +0100 @@ -0,0 +1,2 @@ +/target +Cargo.lock diff -r 000000000000 -r a65053246c29 Cargo.toml --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Cargo.toml Thu Nov 26 22:53:35 2020 +0100 @@ -0,0 +1,21 @@ +[package] +name = "systemd_socket" +version = "0.1.0" +authors = ["Martin Habovstiak "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +serde = ["serde_crate", "serde_str_helpers"] + +[dependencies] +thiserror = "1.0.22" +serde_crate = { package = "serde", version = "1.0.117", optional = true, features = ["derive"] } +serde_str_helpers = { version = "0.1.0", optional = true } +parse_arg = { version = "0.1.4", optional = true } +libsystemd = "0.2.1" +lazy_static = "1.4.0" +tokio_0_2 = { package = "tokio", version = "0.2", optional = true, features = ["tcp"] } +tokio_0_3 = { package = "tokio", version = "0.3", optional = true, features = ["net"] } +async-std = { version = "1.7.0", optional = true } diff -r 000000000000 -r a65053246c29 README.md --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/README.md Thu Nov 26 22:53:35 2020 +0100 @@ -0,0 +1,51 @@ +# systemd socket + +A convenience crate for optionally supporting systemd socket activation. + +## About + +The goal of this crate is to make socket activation with systemd in your project trivial. +It provides a replacement for `std::net::SocketAddr` that allows parsing the bind address from string just like the one from `std` +but on top of that also allows `systemd://socket_name` format that tells it to use systemd activation with given socket name. +Then it provides a method to bind the address which will return the socket from systemd if available. + +The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flag. +Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format. +You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding. + +Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and `async_std` sockets. + +## Example + +```rust +use systemd_socket::SocketAddr; +use std::convert::TryFrom; +use std::io::Write; + +let mut args = std::env::args_os(); +let program_name = args.next().expect("unknown program name"); +let socket_addr = args.next().expect("missing socket address"); +let socket_addr = SocketAddr::try_from(socket_addr).expect("failed to parse socket address"); +let socket = socket_addr.bind().expect("failed to bind socket"); + +loop { + let _ = socket + .accept() + .expect("failed to accept connection") + .0 + .write_all(b"Hello world!") + .map_err(|err| eprintln!("Failed to send {}", err)); +} +``` + +## Features + +* `serde` - implements `serde::Deserialize` for `SocketAddr` +* `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr` +* `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr` +* `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr` +* `async_std` - adds `bind_async_std` convenience method to `SocketAddr` + +## License + +MITNFA diff -r 000000000000 -r a65053246c29 src/error.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/error.rs Thu Nov 26 22:53:35 2020 +0100 @@ -0,0 +1,94 @@ +//! Error types that can occur when dealing with `SocketAddr` +//! +//! This module separates the error types from root module to avoid clutter. + + +use thiserror::Error; +use std::io; + +/// Error that can occur during parsing of `SocketAddr` from a string +/// +/// This encapsulates possible errors that can occur when parsing the input. +/// It is currently opaque because the representation is not certain yet. +/// It can be displayed using the standard `Error` trait. +#[derive(Debug, Error)] +#[error(transparent)] +pub struct ParseError(#[from] pub(crate) ParseErrorInner); + +#[derive(Debug, Error)] +pub(crate) enum ParseErrorInner { + #[error("failed to parse socket address")] + SocketAddr(std::net::AddrParseError), + #[error("invalid character '{c}' in systemd socket name {string} at position {pos}")] + InvalidCharacter { string: String, c: char, pos: usize, } +} + +/// Error that can occur during parsing of `SocketAddr` from a `OsStr`/`OsString` +/// +/// As opposed to parsing from `&str` or `String`, parsing from `OsStr` can fail due to one more +/// reason: invalid UTF-8. This error type expresses this possibility and is returned whenever such +/// conversion is attempted. It is not opaque because the possible variants are pretty much +/// certain, but it may contain `ParseError` which is opaque. +/// +/// This error can be displayed using standard `Error` trait. +/// See `ParseError` for more information. +#[derive(Debug, Error)] +pub enum ParseOsStrError { + /// The input was not a valid UTF-8 string + #[error("the address is not a valid UTF-8 string")] + InvalidUtf8, + /// The input was a valid UTF-8 string but the address was invalid + #[error(transparent)] + InvalidAddress(#[from] ParseError), +} + +/// Error that can occur during binding of a socket +/// +/// This encapsulates possible errors that can occur when binding a socket or receiving a socket +/// from systemd. +/// It is currently opaque because the representation is not certain yet. +/// It can be displayed using the standard `Error` trait. +#[derive(Debug, Error)] +#[error(transparent)] +pub struct BindError(#[from] pub(crate) BindErrorInner); + +#[derive(Debug, Error)] +pub(crate) enum BindErrorInner { + #[error("failed to bind {addr}")] + BindFailed { addr: std::net::SocketAddr, #[source] error: io::Error, }, + #[error("failed to receive descriptors with names")] + ReceiveDescriptors(#[source] crate::systemd_sockets::Error), + #[error("missing systemd socket {0} - a typo or an attempt to bind twice")] + MissingDescriptor(String), + #[error("the systemd socket {0} is not an internet socket")] + NotInetSocket(String), +} + +/// Error that can happen when binding Tokio socket. +/// +/// As opposed to `std` and `async_std` sockets, tokio sockets can fail to convert. +/// This error type expresses this possibility. +#[cfg(any(feature = "tokio_0_2", feature = "tokio_0_3"))] +#[derive(Debug, Error)] +#[error(transparent)] +pub enum TokioBindError { + /// Either binding of socket or receiving systemd socket failed + Bind(#[from] BindError), + /// Conversion from std `std::net::TcpListener` to `tokio::net::TcpListener` failed + Convert(#[from] TokioConversionError), +} + +/// Error that can happen when converting Tokio socket. +/// +/// As opposed to `std` and `async_std` sockets, tokio sockets can fail to convert. +/// This error type encapsulates conversion error together with additional information so that it +/// can be displayed nicely. The encapsulation also allows for future-proofing. +#[cfg(any(feature = "tokio_0_2", feature = "tokio_0_3"))] +#[derive(Debug, Error)] +#[error("failed to convert std socket {addr} into tokio socket")] +pub struct TokioConversionError { + pub(crate) addr: crate::SocketAddrInner, + #[source] + pub(crate) error: io::Error, +} + diff -r 000000000000 -r a65053246c29 src/lib.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/lib.rs Thu Nov 26 22:53:35 2020 +0100 @@ -0,0 +1,310 @@ +//! A convenience crate for optionally supporting systemd socket activation. +//! +//! ## About +//! +//! The goal of this crate is to make socket activation with systemd in your project trivial. +//! It provides a replacement for `std::net::SocketAddr` that allows parsing the bind address from string just like the one from `std` +//! but on top of that also allows `systemd://socket_name` format that tells it to use systemd activation with given socket name. +//! Then it provides a method to bind the address which will return the socket from systemd if available. +//! +//! The provided type supports conversions from various types of strings and also `serde` and `parse_arg` via feature flag. +//! Thanks to this the change to your code should be minimal - parsing will continue to work, it'll just allow a new format. +//! You only need to change the code to use `SocketAddr::bind()` instead of `TcpListener::bind()` for binding. +//! +//! Further, the crate also provides convenience methods for binding `tokio` 0.2, 0.3, and +//! `async_std` sockets if the appropriate features are activated. +//! +//! ## Example +//! +//! ```no_run +//! use systemd_socket::SocketAddr; +//! use std::convert::TryFrom; +//! use std::io::Write; +//! +//! let mut args = std::env::args_os(); +//! let program_name = args.next().expect("unknown program name"); +//! let socket_addr = args.next().expect("missing socket address"); +//! let socket_addr = SocketAddr::try_from(socket_addr).expect("failed to parse socket address"); +//! let socket = socket_addr.bind().expect("failed to bind socket"); +//! +//! loop { +//! let _ = socket +//! .accept() +//! .expect("failed to accept connection") +//! .0 +//! .write_all(b"Hello world!") +//! .map_err(|err| eprintln!("Failed to send {}", err)); +//! } +//! ``` +//! +//! ## Features +//! +//! * `serde` - implements `serde::Deserialize` for `SocketAddr` +//! * `parse_arg` - implements `parse_arg::ParseArg` for `SocketAddr` +//! * `tokio_0_2` - adds `bind_tokio_0_2` convenience method to `SocketAddr` +//! * `tokio_0_3` - adds `bind_tokio_0_3` convenience method to `SocketAddr` +//! * `async_std` - adds `bind_async_std` convenience method to `SocketAddr` + +#![deny(missing_docs)] + +pub mod error; + +use std::convert::{TryFrom, TryInto}; +use std::fmt; +use std::ffi::{OsStr, OsString}; +use crate::error::*; + +pub(crate) mod systemd_sockets { + use std::fmt; + use std::sync::Mutex; + use libsystemd::activation::FileDescriptor; + use libsystemd::errors::Error as LibSystemdError; + use libsystemd::errors::Result as LibSystemdResult; + + #[derive(Debug)] + pub(crate) struct Error(&'static Mutex); + + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&*self.0.lock().expect("mutex poisoned"), f) + } + } + + // No source we can't keep the mutex locked + impl std::error::Error for Error {} + + pub(crate) fn take(name: &str) -> Result, Error> { + match &*SYSTEMD_SOCKETS { + Ok(sockets) => Ok(sockets.take(name)), + Err(error) => Err(Error(error)) + } + } + + struct SystemdSockets(std::sync::Mutex>); + + impl SystemdSockets { + fn new() -> LibSystemdResult { + // MUST BE true FOR SAFETY!!! + let map = libsystemd::activation::receive_descriptors_with_names(/*unset env = */ true)?.into_iter().map(|(fd, name)| (name, fd)).collect(); + Ok(SystemdSockets(Mutex::new(map))) + } + + fn take(&self, name: &str) -> Option { + // MUST remove THE SOCKET FOR SAFETY!!! + self.0.lock().expect("poisoned mutex").remove(name) + } + } + + lazy_static::lazy_static! { + // We don't panic in order to let the application handle the error later + static ref SYSTEMD_SOCKETS: Result> = SystemdSockets::new().map_err(Mutex::new); + } +} + +/// Socket address that can be an ordinary address or a systemd socket +/// +/// This is the core type of this crate that abstracts possible addresses. +/// It can be (fallibly) converted from various types of strings or deserialized with `serde`. +/// After it's created, it can be bound as `TcpListener` from `std` or even `tokio` or `async_std` +/// if the appropriate feature is enabled. +/// +/// Optional dependencies on `parse_arg` and `serde` make it trivial to use with +/// [`configure_me`](https://crates.io/crates/configure_me). +#[derive(Debug)] +#[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))] +pub struct SocketAddr(SocketAddrInner); + +impl SocketAddr { + /// Creates `std::net::TcpListener` + /// + /// This method either `binds` the socket, if the address was provided or uses systemd socket + /// if the socket name was provided. + pub fn bind(self) -> Result { + self._bind().map(|(socket, _)| socket) + } + + /// Creates `tokio::net::TcpListener` + /// + /// To be specific, it binds the socket and converts it to `tokio` 0.2 socket. + /// + /// This method either `binds` the socket, if the address was provided or uses systemd socket + /// if the socket name was provided. + #[cfg(feature = "tokio_0_2")] + pub fn bind_tokio_0_2(self) -> Result { + let (socket, addr) = self._bind()?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + } + + /// Creates `tokio::net::TcpListener` + /// + /// To be specific, it binds the socket and converts it to `tokio` 0.3 socket. + /// + /// This method either `binds` the socket, if the address was provided or uses systemd socket + /// if the socket name was provided. + #[cfg(feature = "tokio_0_3")] + pub fn bind_tokio_0_3(self) -> Result { + let (socket, addr) = self._bind()?; + socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into()) + } + + /// Creates `async_std::net::TcpListener` + /// + /// To be specific, it binds the socket and converts it to `async_std` socket. + /// + /// This method either `binds` the socket, if the address was provided or uses systemd socket + /// if the socket name was provided. + #[cfg(feature = "async-std")] + pub fn bind_async_std(self) -> Result { + let (socket, _) = self._bind()?; + Ok(socket.into()) + } + + // We can't impl + Into> TryFrom for SocketAddr because of orphan + // rules. + fn try_from_generic<'a, T>(string: T) -> Result where T: 'a + std::ops::Deref + Into { + if string.starts_with(SYSTEMD_PREFIX) { + match string[SYSTEMD_PREFIX.len()..].chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') { + None => Ok(SocketAddr(SocketAddrInner::Systemd(string.into()))), + Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: string.into(), c, pos, }.into()), + } + } else { + Ok(string.parse().map(SocketAddrInner::Ordinary).map(SocketAddr).map_err(ParseErrorInner::SocketAddr)?) + } + } + + fn _bind(self) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> { + match self.0 { + SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) { + Ok(socket) => Ok((socket, SocketAddrInner::Ordinary(addr))), + Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()), + }, + SocketAddrInner::Systemd(socket_name) => { + use libsystemd::activation::IsType; + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let socket = systemd_sockets::take(&socket_name[SYSTEMD_PREFIX.len()..]).map_err(BindErrorInner::ReceiveDescriptors)?; + // Safety: The environment variable is unset, so that no other calls can get the + // descriptors. The descriptors are taken from the map, not cloned, so they can't + // be duplicated. + unsafe { + // match instead of combinators to avoid cloning socket_name + match socket { + Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))), + Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()), + None => Err(BindErrorInner::MissingDescriptor(socket_name).into()) + } + } + }, + } + } +} + +/// Displays the address in format that can be parsed again. +/// +/// **Important: While I don't expect this impl to change, don't rely on it!** +/// It should be used mostly for debugging/logging. +impl fmt::Display for SocketAddr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl fmt::Display for SocketAddrInner { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + SocketAddrInner::Ordinary(addr) => fmt::Display::fmt(addr, f), + SocketAddrInner::Systemd(addr) => fmt::Display::fmt(addr, f), + } + } +} + +// PartialEq for testing, I'm not convinced it should be exposed +#[derive(Debug, PartialEq)] +enum SocketAddrInner { + Ordinary(std::net::SocketAddr), + Systemd(String), +} + +const SYSTEMD_PREFIX: &str = "systemd://"; + +impl std::str::FromStr for SocketAddr { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + SocketAddr::try_from_generic(s) + } +} + +impl<'a> TryFrom<&'a str> for SocketAddr { + type Error = ParseError; + + fn try_from(s: &'a str) -> Result { + SocketAddr::try_from_generic(s) + } +} + +impl TryFrom for SocketAddr { + type Error = ParseError; + + fn try_from(s: String) -> Result { + SocketAddr::try_from_generic(s) + } +} + +impl<'a> TryFrom<&'a OsStr> for SocketAddr { + type Error = ParseOsStrError; + + fn try_from(s: &'a OsStr) -> Result { + s.to_str().ok_or(ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into) + } +} + +impl TryFrom for SocketAddr { + type Error = ParseOsStrError; + + fn try_from(s: OsString) -> Result { + s.into_string().map_err(|_| ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into) + } +} + +#[cfg(feature = "serde")] +impl<'a> TryFrom> for SocketAddr { + type Error = ParseError; + + fn try_from(s: serde_str_helpers::DeserBorrowStr<'a>) -> Result { + SocketAddr::try_from_generic(std::borrow::Cow::from(s)) + } +} + +#[cfg(feature = "parse_arg")] +impl parse_arg::ParseArg for SocketAddr { + type Error = ParseOsStrError; + + fn describe_type(mut writer: W) -> fmt::Result { + std::net::SocketAddr::describe_type(&mut writer)?; + write!(writer, " or a systemd socket name prefixed with systemd://") + } + + fn parse_arg(arg: &OsStr) -> Result { + arg.try_into() + } + + fn parse_owned_arg(arg: OsString) -> Result { + arg.try_into() + } +} + +#[cfg(test)] +mod tests { + use super::{SocketAddr, SocketAddrInner}; + + #[test] + fn parse_ordinary() { + assert_eq!("127.0.0.1:42".parse::().unwrap().0, SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into())); + } + + #[test] + fn parse_systemd() { + assert_eq!("systemd://foo".parse::().unwrap().0, SocketAddrInner::Systemd("systemd://foo".to_owned())); + } +}