view libpam-sys/libpam-sys-impls/src/lib.rs @ 110:2346fd501b7a

Add tests for constants and do other macro niceties. - Adds tests for all the constants. Pretty sweet. - Moves documentation for cfg-pam-impl macro to `libpam-sys`. - Renames `Illumos` to `Sun`. - other stuff
author Paul Fisher <paul@pfish.zone>
date Sun, 29 Jun 2025 02:15:46 -0400
parents bb465393621f
children 178310336596
line wrap: on
line source

use proc_macro as pm;
use proc_macro2::{Delimiter, Group, Literal, Span, TokenStream, TokenTree};
use quote::{format_ident, quote};
use std::fmt::Display;
use std::str::FromStr;
use syn::Lit;

// For documentation on this, see the `libpam-sys` crate.
#[proc_macro_attribute]
pub fn cfg_pam_impl(attr: pm::TokenStream, item: pm::TokenStream) -> pm::TokenStream {
    Predicate::parse(attr.into(), None)
        .map(|p| {
            if p.matches(pam_impl_str()) {
                item
            } else {
                pm::TokenStream::new()
            }
        })
        .unwrap_or_else(|e| syn::Error::from(e).into_compile_error().into())
}

/// Outputs the `PamImpl` enum and `LIBPAMSYS_IMPL` constant.
/// For use only in `libpam-sys`.
///
/// The tokens passed into the macro are pasted immediately before the enum.
#[proc_macro]
pub fn __pam_impl_enum__(data: pm::TokenStream) -> pm::TokenStream {
    let variant = format_ident!("{}", pam_impl_str());
    TokenStream::from_iter([
        data.into(),
        TokenStream::from_str(include_str!(concat!(env!("OUT_DIR"), "/pam_impl_enum.rs"))).unwrap(),
        quote!(
            impl PamImpl {
                #[doc = concat!("The PAM implementation this was built for (currently `", stringify!(#variant), ")`.")]
                pub const CURRENT: Self = Self::#variant;
            }
        ),
    ]).into()
}

/// The name of the PAM implementation. For use only in `libpam-sys`.
#[proc_macro]
pub fn __pam_impl_name__(data: pm::TokenStream) -> pm::TokenStream {
    if !data.is_empty() {
        panic!("pam_impl_name! does not take any input")
    }
    pm::TokenTree::Literal(pm::Literal::string(pam_impl_str())).into()
}

/// Returns the PAM implementation name captured at compile time from the
/// `LIBPAMSYS_IMPL` environment variable.
fn pam_impl_str() -> &'static str {
    const IMPL_NAME: &str = env!("LIBPAMSYS_IMPL");
    IMPL_NAME
}

#[derive(Debug)]
enum Error {
    WithSpan(syn::Error),
    WithoutSpan(String),
}

impl Error {
    fn new<D: Display>(span: Option<Span>, msg: D) -> Self {
        match span {
            Some(span) => syn::Error::new(span, msg).into(),
            None => Self::WithoutSpan(msg.to_string()),
        }
    }
}

impl From<syn::Error> for Error {
    fn from(value: syn::Error) -> Self {
        Self::WithSpan(value)
    }
}

impl From<String> for Error {
    fn from(value: String) -> Self {
        Self::WithoutSpan(value)
    }
}

impl From<Error> for syn::Error {
    fn from(value: Error) -> Self {
        match value {
            Error::WithSpan(e) => e,
            Error::WithoutSpan(s) => syn::Error::new(Span::call_site(), s),
        }
    }
}

type Result<T> = std::result::Result<T, Error>;

#[derive(Debug)]
enum Predicate {
    Literal(String),
    Any(Vec<String>),
    Not(Box<Predicate>),
}

impl Predicate {
    fn matches(&self, value: &str) -> bool {
        match self {
            Self::Literal(literal) => value == literal,
            Self::Not(pred) => !pred.matches(value),
            Self::Any(options) => options.iter().any(|s| s == value),
        }
    }

    fn parse(stream: TokenStream, span: Option<Span>) -> Result<Self> {
        let mut iter = stream.into_iter();
        let pred = match iter.next() {
            None => return error(span, "a PAM implementation predicate must be provided"),
            Some(TokenTree::Literal(lit)) => Self::Literal(Self::string_lit(lit)?),
            Some(TokenTree::Ident(id)) => {
                let next = Self::parens(iter.next(), span)?;
                match id.to_string().as_str() {
                    "not" => Self::Not(Box::new(Self::parse(next.stream(), Some(next.span()))?)),
                    "any" => Self::Any(Self::parse_any(next)?),
                    _ => return unexpected(&id.into(), "\"not\" or \"any\""),
                }
            }
            Some(other) => return unexpected(&other, "\"not\", \"any\", or a string literal"),
        };
        // Check for anything after. We only allow a comma and nothing else.
        if maybe_comma(iter.next())? {
            if let Some(next) = iter.next() {
                return unexpected(&next, "nothing");
            }
        }
        Ok(pred)
    }

    fn parens(tree: Option<TokenTree>, mut span: Option<Span>) -> Result<Group> {
        if let Some(tree) = tree {
            span = Some(tree.span());
            if let TokenTree::Group(g) = tree {
                if g.delimiter() == Delimiter::Parenthesis {
                    return Ok(g);
                }
            }
        }
        Err(Error::new(span, "expected function-call syntax"))
    }

    fn parse_any(g: Group) -> Result<Vec<String>> {
        let mut output = Vec::new();
        let mut iter = g.stream().into_iter();
        loop {
            match iter.next() {
                None => break,
                Some(TokenTree::Literal(lit)) => {
                    output.push(Self::string_lit(lit)?);
                    if !maybe_comma(iter.next())? {
                        break;
                    }
                }
                Some(other) => return unexpected(&other, "string literal"),
            }
        }
        Ok(output)
    }

    fn string_lit(lit: Literal) -> Result<String> {
        let tree: TokenTree = lit.clone().into();
        match syn::parse2::<Lit>(tree.into())? {
            Lit::Str(s) => Ok(s.value()),
            _ => unexpected(&lit.into(), "string literal"),
        }
    }
}

fn error<T, M: Display>(span: Option<Span>, message: M) -> Result<T> {
    Err(Error::new(span, message))
}

fn unexpected<T>(tree: &TokenTree, want: &str) -> Result<T> {
    error(
        Some(tree.span()),
        format!("expected {want}; got unexpected token {tree}"),
    )
}

/// Checks the token following an item.
///
/// Returns `Ok(false)` at end of input, `Ok(true)` for a `,` separator,
/// and an error for any other token.
fn maybe_comma(next: Option<TokenTree>) -> Result<bool> {
    match next {
        Some(TokenTree::Punct(p)) if p.as_char() == ',' => Ok(true),
        Some(other) => unexpected(&other, "',' or ')'"),
        None => Ok(false),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tree`, panicking with the offending input and error if it is rejected.
    fn parse(tree: TokenStream) -> Predicate {
        let shown = tree.to_string();
        Predicate::parse(tree, None)
            .unwrap_or_else(|e| panic!("{shown:?} should parse, but failed: {e:?}"))
    }

    #[test]
    fn test_parse() {
        // Each parenthesized entry becomes one token stream via `quote!`.
        macro_rules! cases {
            ($(($($i:tt)*)),* $(,)?) => { [ $( quote!($($i)*) ),* ] };
        }

        let good = cases![
            ("this"),
            (any("this", "that", "the other")),
            (not("the bees")),
            (not(any("of", "those"))),
            (not(not("saying it"))),
            (any("trailing", "comma", "allowed",)),
            ("even on a singleton",),
            (not("forbidden here either",)),
            (not(not(any("this", "is", "stupid"),),),),
        ];
        for tree in good {
            parse(tree);
        }
        let bad = cases![
            (),
            (wrong),
            (wheel::of::fortune),
            ("invalid", "syntax"),
            (any(any)),
            (any),
            (not),
            (not(any)),
            ("too many commas",,),
            (any("too", "many",, "commas")),
            (not("the commas",,,)),
            (9),
            (any("123", 8)),
            (not(666)),
        ];
        for tree in bad {
            let shown = tree.to_string();
            if let Ok(pred) = Predicate::parse(tree, None) {
                panic!("{shown:?} should not parse, but got {pred:?}");
            }
        }
    }

    #[test]
    fn test_match() {
        // Each entry is `(implementation name, (predicate tokens))`.
        macro_rules! cases {
            ($(($e:expr, ($($i:tt)*))),* $(,)?) => {
                [$(($e, quote!($($i)*))),*]
            }
        }
        let matching = cases![
            ("Sun", (any("Sun", "OpenPam"))),
            ("OpenPam", (any("Sun", "OpenPam"))),
            ("LinuxPam", (not("OpenPam"))),
            ("MinimalOpenPam", (not("OpenPam"))),
            ("Other", (not(any("This", "That")))),
            ("OpenPam", (not(not("OpenPam")))),
            ("Anything", (not(any()))),
        ];
        for (good, tree) in matching {
            let pred = parse(tree);
            assert!(pred.matches(good), "{pred:?} should match {good:?}");
        }

        let nonmatching = cases![
            ("LinuxPam", (not("LinuxPam"))),
            ("Sun", ("LinuxPam")),
            ("OpenPam", (any("LinuxPam", "Sun"))),
            ("One", (not(any("One", "Another")))),
            ("Negatory", (not(not("Affirmative")))),
            // Names must match exactly; a prefix or suffix is not enough.
            ("MinimalOpenPam", ("OpenPam")),
            ("OpenPam", ("MinimalOpenPam")),
        ];
        for (bad, tree) in nonmatching {
            let pred = parse(tree);
            assert!(!pred.matches(bad), "{pred:?} should not match {bad:?}");
        }
    }
}