view src/conv.rs @ 76:e58d24849e82

Add Message::set_error to quickly answer a question with an error.
author Paul Fisher <paul@pfish.zone>
date Sat, 07 Jun 2025 18:55:27 -0400
parents c7c596e6388f
children 351bdc13005e
line wrap: on
line source

//! The PAM conversation and associated Stuff.

// Temporarily allowed until we get the actual conversation functions hooked up.
#![allow(dead_code)]

use crate::constants::Result;
use crate::ErrorCode;
use secure_string::SecureString;
use std::cell::Cell;

/// The types of message and request that can be sent to a user.
///
/// The data within each enum value is the prompt (or other information)
/// that will be presented to the user.
///
/// Each variant borrows the Q-and-A pair it was created from (see
/// [`QAndA::message`]); a [`Conversation`] answers the message by calling
/// [`QAndA::set_answer`] on that borrowed pair.
#[non_exhaustive]
pub enum Message<'a> {
    /// Asks for data that should not be echoed back (e.g. a password).
    MaskedPrompt(&'a MaskedPrompt<'a>),
    /// Asks for data; the user's input is shown as they type.
    Prompt(&'a Prompt<'a>),
    /// Asks for "radio button"–style data (Linux-PAM extension).
    RadioPrompt(&'a RadioPrompt<'a>),
    /// Asks for binary data (Linux-PAM extension).
    BinaryPrompt(&'a BinaryPrompt<'a>),
    /// Informational text to pass to the user.
    InfoMsg(&'a InfoMsg<'a>),
    /// An error message to pass to the user.
    ErrorMsg(&'a ErrorMsg<'a>),
}

impl Message<'_> {
    /// Sets an error answer on this question, without having to inspect it.
    ///
    /// This works for every kind of message, including [`InfoMsg`] and
    /// [`ErrorMsg`]: it calls [`QAndA::set_answer`] with `Err(err)` on
    /// whichever Q-and-A pair this message borrows.
    ///
    /// Use this as a default match case:
    ///
    /// ```
    /// use nonstick::conv::{Message, QAndA};
    /// use nonstick::ErrorCode;
    ///
    /// fn cant_respond(message: Message) {
    ///     match message {
    ///         Message::InfoMsg(i) => {
    ///             eprintln!("fyi, {}", i.question());
    ///             i.set_answer(Ok(()))
    ///         }
    ///         Message::ErrorMsg(e) => {
    ///             eprintln!("ERROR: {}", e.question());
    ///             e.set_answer(Ok(()))
    ///         }
    ///         // We can't answer any questions.
    ///         other => other.set_error(ErrorCode::ConversationError),
    ///     }
    /// }
    /// ```
    pub fn set_error(&self, err: ErrorCode) {
        // Each variant wraps a different concrete type, so the call has to be
        // spelled out once per variant.
        match self {
            Message::MaskedPrompt(m) => m.set_answer(Err(err)),
            Message::Prompt(m) => m.set_answer(Err(err)),
            Message::RadioPrompt(m) => m.set_answer(Err(err)),
            Message::BinaryPrompt(m) => m.set_answer(Err(err)),
            Message::InfoMsg(m) => m.set_answer(Err(err)),
            Message::ErrorMsg(m) => m.set_answer(Err(err)),
        }
    }
}

/// A question-and-answer pair that can be communicated in a [`Conversation`].
///
/// The asking side creates a `QAndA`, then converts it to a [`Message`]
/// and sends it via a [`Conversation`]. The Conversation then retrieves
/// the answer to the question (if needed) and sets the response.
/// Once control returns to the asker, the asker gets the answer from this
/// `QAndA` and uses it however it wants.
///
/// For a more detailed explanation of how this works,
/// see [`Conversation::communicate`].
pub trait QAndA<'a> {
    /// The type of the content of the question.
    type Question: Copy;
    /// The type of the answer to the question.
    type Answer;

    /// Converts this Q-and-A pair into a [`Message`] for the [`Conversation`].
    ///
    /// The returned `Message` borrows from `self`.
    fn message(&self) -> Message<'_>;

    /// The contents of the question being asked.
    ///
    /// For instance, this might say `"Username:"` to prompt the user
    /// for their name.
    fn question(&self) -> Self::Question;

    /// Sets the answer to the question.
    ///
    /// The [`Conversation`] implementation calls this to set the answer.
    /// The conversation should *always call this function*, even for messages
    /// that don't have "an answer" (like error or info messages).
    fn set_answer(&self, answer: Result<Self::Answer>);

    /// Gets the answer to the question.
    ///
    /// This consumes the Q-and-A pair; call it once the conversation is done.
    fn answer(self) -> Result<Self::Answer>;
}

// Defines a Q-and-A struct named `$name` with question type `$qt` and answer
// type `$at`, documented by the given string literals, and implements
// `QAndA` for it. The answer lives in a `Cell` so that a `Conversation`
// holding only a shared reference (via `Message`) can still set it.
macro_rules! q_and_a {
    ($name:ident<'a, Q=$qt:ty, A=$at:ty>, $($doc:literal)*) => {
        $(
            #[doc = $doc]
        )*
        pub struct $name<'a> {
            q: $qt,
            a: Cell<Result<$at>>,
        }

        impl<'a> QAndA<'a> for $name<'a> {
            type Question = $qt;
            type Answer = $at;

            fn message(&self) -> Message<'_> {
                Message::$name(self)
            }

            fn question(&self) -> Self::Question {
                self.q
            }

            fn set_answer(&self, answer: Result<Self::Answer>) {
                // Overwrites any previously-set answer.
                self.a.set(answer)
            }

            fn answer(self) -> Result<Self::Answer> {
                self.a.into_inner()
            }
        }
    };
}

// Adds an `ask(question: &str)` constructor to a `q_and_a!` type whose
// question is a string. The new value starts out unanswered.
macro_rules! ask {
    ($prompt:ident) => {
        impl<'a> $prompt<'a> {
            #[doc = concat!("Creates a `", stringify!($prompt), "` to be sent to the user.")]
            fn ask(question: &'a str) -> Self {
                // Until a `Conversation` answers it, report a conversation error.
                let unanswered = Cell::new(Err(ErrorCode::ConversationError));
                Self {
                    q: question,
                    a: unanswered,
                }
            }
        }
    };
}

// Defines `MaskedPrompt` (question: `&str`, answer: `SecureString`);
// `ask!` adds its `ask` constructor.
q_and_a!(
    MaskedPrompt<'a, Q=&'a str, A=SecureString>,
    "Asks the user for data and does not echo it back while being entered."
    ""
    "In other words, a password entry prompt."
);
ask!(MaskedPrompt);

// Defines `Prompt` (question: `&str`, answer: `String`);
// `ask!` adds its `ask` constructor.
q_and_a!(
    Prompt<'a, Q=&'a str, A=String>,
    "Asks the user for data."
    ""
    "This is the normal \"ask a person a question\" prompt."
    "When the user types, their input will be shown to them."
    "It can be used for things like usernames."
);
ask!(Prompt);

// Defines `RadioPrompt` (question: `&str`, answer: `String`);
// `ask!` adds its `ask` constructor.
q_and_a!(
    RadioPrompt<'a, Q=&'a str, A=String>,
    "Asks the user for \"radio button\"–style data. (Linux-PAM extension)"
    ""
    "This message type is theoretically useful for \"yes/no/maybe\""
    "questions, but nowhere in the documentation is it specified"
    "what the format of the answer will be, or how this should be shown."
);
ask!(RadioPrompt);

q_and_a!(
    BinaryPrompt<'a, Q=BinaryQuestion<'a>, A=BinaryData>,
    "Asks for binary data. (Linux-PAM extension)"
    ""
    "This sends a binary message to the client application."
    "It can be used to communicate with non-human logins,"
    "or to enable things like security keys."
    ""
    "The `data_type` tag is a value that is simply passed through"
    "to the application. PAM does not define any meaning for it."
);
impl<'a> BinaryPrompt<'a> {
    /// Builds a prompt carrying `data`, tagged with `data_type`.
    ///
    /// PAM gives `data_type` no meaning; it is an opaque tag for the module
    /// and the application to agree on. Until a [`Conversation`] answers
    /// the prompt, its answer is `Err(ErrorCode::ConversationError)`.
    fn ask(data: &'a [u8], data_type: u8) -> Self {
        let question = BinaryQuestion { data, data_type };
        let unanswered = Cell::new(Err(ErrorCode::ConversationError));
        Self {
            q: question,
            a: unanswered,
        }
    }
}

/// The contents of a question requesting binary data.
///
/// A borrowed version of [`BinaryData`].
#[derive(Copy, Clone, Debug)]
pub struct BinaryQuestion<'a> {
    data: &'a [u8],
    data_type: u8,
}

impl<'a> BinaryQuestion<'a> {
    /// Gets the data of this question.
    ///
    /// The slice borrows from the underlying data, not from this
    /// `BinaryQuestion`. It therefore stays valid after this value is dropped,
    /// so `prompt.question().data()` can be stored in a variable.
    pub fn data(&self) -> &'a [u8] {
        self.data
    }

    /// Gets the "type" of this data.
    ///
    /// This is an opaque tag; PAM does not define any meaning for it.
    pub fn data_type(&self) -> u8 {
        self.data_type
    }
}

/// Owned binary data.
///
/// For borrowed data, see [`BinaryQuestion`].
/// You can take ownership of the stored bytes by converting into a
/// `Vec<u8>`, e.g. `let bytes: Vec<u8> = data.into();` or
/// `Vec::from(data)`:
///
/// ```
/// use nonstick::BinaryData;
///
/// let bin = BinaryData::new(vec![1, 2, 3], 7);
/// let bytes: Vec<u8> = bin.into();
/// assert_eq!(vec![1, 2, 3], bytes);
/// ```
#[derive(Debug, PartialEq)]
pub struct BinaryData {
    data: Vec<u8>,
    data_type: u8,
}

impl BinaryData {
    /// Creates a `BinaryData` with the given contents and type.
    pub fn new(data: Vec<u8>, data_type: u8) -> Self {
        Self { data, data_type }
    }

    /// A borrowed view of the data here.
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// The type of the data stored in this.
    pub fn data_type(&self) -> u8 {
        self.data_type
    }
}

impl From<BinaryData> for Vec<u8> {
    /// Takes ownership of the data stored herein, discarding the type tag.
    fn from(value: BinaryData) -> Self {
        value.data
    }
}

q_and_a!(
    InfoMsg<'a, Q = &'a str, A = ()>,
    "A message containing information to be passed to the user."
    ""
    "While this does not have an answer, [`Conversation`] implementations"
    "should still call [`set_answer`][`QAndA::set_answer`] to verify that"
    "the message has been displayed (or actively discarded)."
);
impl<'a> InfoMsg<'a> {
    /// Builds an informational message to show the user.
    ///
    /// Its answer is `Err(ErrorCode::ConversationError)` until a
    /// [`Conversation`] acknowledges the message.
    fn new(message: &'a str) -> Self {
        let pending = Err(ErrorCode::ConversationError);
        Self {
            a: Cell::new(pending),
            q: message,
        }
    }
}

q_and_a!(
    ErrorMsg<'a, Q = &'a str, A = ()>,
    "An error message to be passed to the user."
    ""
    "While this does not have an answer, [`Conversation`] implementations"
    "should still call [`set_answer`][`QAndA::set_answer`] to verify that"
    "the message has been displayed (or actively discarded)."
);
impl<'a> ErrorMsg<'a> {
    /// Builds an error message to show the user.
    ///
    /// Its answer is `Err(ErrorCode::ConversationError)` until a
    /// [`Conversation`] acknowledges the message.
    fn new(message: &'a str) -> Self {
        let pending = Err(ErrorCode::ConversationError);
        Self {
            a: Cell::new(pending),
            q: message,
        }
    }
}

/// A channel for PAM modules to request information from the user.
///
/// This trait is used by both applications and PAM modules:
///
/// - Applications implement Conversation and provide a user interface
///   to allow the user to respond to PAM questions.
/// - Modules call a Conversation implementation to request information
///   or send information to the user.
pub trait Conversation {
    /// Sends messages to the user.
    ///
    /// Nothing is returned directly. Instead, the implementation answers each
    /// message in place by calling [`QAndA::set_answer`] on it. For messages
    /// it cannot handle, [`Message::set_error`] does this without inspecting
    /// the message. Every message should be answered, including info and
    /// error messages. After this returns, the asker reads each result with
    /// [`QAndA::answer`]. A message left unanswered keeps the answer it was
    /// constructed with, which for the constructors in this module is
    /// `Err(ErrorCode::ConversationError)`.
    ///
    /// TODO: write detailed documentation about how to use this.
    fn communicate(&mut self, messages: &[Message]);
}

/// Turns a simple function into a [`Conversation`].
///
/// Each call to [`Conversation::communicate`] on the result passes the whole
/// slice of messages to `func` in a single call.
///
/// This can be used to wrap a free-floating function for use as a
/// Conversation:
///
/// ```
/// use nonstick::conv::{Conversation, Message, conversation_func};
/// mod some_library {
/// #    use nonstick::Conversation;
///     pub fn get_auth_data(conv: &mut impl Conversation) { /* ... */ }
/// }
///
/// fn my_terminal_prompt(messages: &[Message]) {
///     // ...
/// #    todo!()
/// }
///
/// fn main() {
///     some_library::get_auth_data(&mut conversation_func(my_terminal_prompt));
/// }
/// ```
pub fn conversation_func(func: impl FnMut(&[Message])) -> impl Conversation {
    Convo(func)
}

/// Adapter that implements [`Conversation`] by calling the wrapped function.
///
/// The `FnMut` bound is placed only on the `impl` that needs it, not on the
/// struct definition.
struct Convo<C>(C);

impl<C: FnMut(&[Message])> Conversation for Convo<C> {
    fn communicate(&mut self, messages: &[Message]) {
        self.0(messages)
    }
}

/// A conversation trait for asking or answering one question at a time.
///
/// An implementation of this is provided for any [`Conversation`],
/// or a PAM application can implement this trait and handle messages
/// one at a time.
///
/// For example, to use a `Conversation` as a `SimpleConversation`:
///
/// ```
/// # use nonstick::{Conversation, Result};
/// # use secure_string::SecureString;
/// // Bring this trait into scope to get `masked_prompt`, among others.
/// use nonstick::SimpleConversation;
///
/// fn ask_for_token(convo: &mut impl Conversation) -> Result<SecureString> {
///     convo.masked_prompt("enter your one-time token")
/// }
/// ```
///
/// or to use a `SimpleConversation` as a `Conversation`:
///
/// ```
/// use secure_string::SecureString;
/// use nonstick::{Conversation, SimpleConversation};
/// # use nonstick::{BinaryData, Result};
/// mod some_library {
/// #    use nonstick::Conversation;
///     pub fn get_auth_data(conv: &mut impl Conversation) { /* ... */ }
/// }
///
/// struct MySimpleConvo { /* ... */ }
/// # impl MySimpleConvo { fn new() -> Self { Self{} } }
///
/// impl SimpleConversation for MySimpleConvo {
///     // ...
/// # fn prompt(&mut self, request: &str) -> Result<String> {
/// #     todo!()
/// # }
/// #
/// # fn masked_prompt(&mut self, request: &str) -> Result<SecureString> {
/// #     todo!()
/// # }
/// #
/// # fn radio_prompt(&mut self, request: &str) -> Result<String> {
/// #     todo!()
/// # }
/// #
/// # fn error_msg(&mut self, message: &str) {
/// #     todo!()
/// # }
/// #
/// # fn info_msg(&mut self, message: &str) {
/// #     todo!()
/// # }
/// #
/// # fn binary_prompt(&mut self, data: &[u8], data_type: u8) -> Result<BinaryData> {
/// #     todo!()
/// # }
/// }
///
/// fn main() {
///     let mut simple = MySimpleConvo::new();
///     some_library::get_auth_data(&mut simple.as_conversation())
/// }
/// ```
pub trait SimpleConversation {
    /// Lets you use this simple conversation as a full [`Conversation`].
    ///
    /// The wrapper takes each message received in [`Conversation::communicate`]
    /// and passes them one-by-one to the appropriate method,
    /// then stores each method's result as the answer to that message.
    fn as_conversation(&mut self) -> Demux<'_, Self>
    where
        Self: Sized,
    {
        Demux(self)
    }
    /// Prompts the user for something.
    fn prompt(&mut self, request: &str) -> Result<String>;
    /// Prompts the user for something, but hides what the user types.
    fn masked_prompt(&mut self, request: &str) -> Result<SecureString>;
    /// Prompts the user for a yes/no/maybe conditional (a Linux-PAM extension).
    ///
    /// PAM documentation doesn't define the format of the response.
    fn radio_prompt(&mut self, request: &str) -> Result<String>;
    /// Alerts the user to an error.
    fn error_msg(&mut self, message: &str);
    /// Sends an informational message to the user.
    fn info_msg(&mut self, message: &str);
    /// Requests binary data from the user (a Linux-PAM extension).
    ///
    /// `data_type` is an opaque tag; PAM does not define any meaning for it.
    fn binary_prompt(&mut self, data: &[u8], data_type: u8) -> Result<BinaryData>;
}

// Generates a `SimpleConversation` method for any `Conversation`: build the
// Q-and-A value with `$ask`, send it alone through `communicate`, and (in the
// first form) hand back whatever answer the conversation set.
macro_rules! conv_fn {
    // Methods that return an answer.
    ($fn_name:ident($($param:ident: $pt:ty),+) -> $resp_type:ty { $ask:path }) => {
        fn $fn_name(&mut self, $($param: $pt),*) -> Result<$resp_type> {
            let question = $ask($($param),*);
            self.communicate(&[question.message()]);
            question.answer()
        }
    };
    // Methods with nothing to return (info and error messages).
    ($fn_name:ident($($param:ident: $pt:ty),+) { $ask:path }) => {
        fn $fn_name(&mut self, $($param: $pt),*) {
            let notice = $ask($($param),*);
            self.communicate(&[notice.message()]);
        }
    };
}

// Every `Conversation` is also a `SimpleConversation`. Each method wraps its
// argument(s) in the matching Q-and-A type and sends it as a one-element batch
// through `communicate`. The prompt methods then return the answer that was
// set; `error_msg` and `info_msg` return nothing.
impl<C: Conversation> SimpleConversation for C {
    conv_fn!(prompt(message: &str) -> String { Prompt::ask });
    conv_fn!(masked_prompt(message: &str) -> SecureString { MaskedPrompt::ask });
    conv_fn!(radio_prompt(message: &str) -> String { RadioPrompt::ask });
    conv_fn!(error_msg(message: &str) { ErrorMsg::new });
    conv_fn!(info_msg(message: &str) { InfoMsg::new });
    conv_fn!(binary_prompt(data: &[u8], data_type: u8) -> BinaryData { BinaryPrompt::ask });
}

/// A [`Conversation`] which asks the questions one at a time.
///
/// This is automatically created by [`SimpleConversation::as_conversation`].
/// Each message is forwarded, in order, to the matching
/// [`SimpleConversation`] method, and that method's result becomes the
/// message's answer.
pub struct Demux<'a, SC>(&'a mut SC);

impl<SC: SimpleConversation> Conversation for Demux<'_, SC> {
    fn communicate(&mut self, messages: &[Message]) {
        for msg in messages {
            match msg {
                Message::Prompt(prompt) => prompt.set_answer(self.0.prompt(prompt.question())),
                Message::MaskedPrompt(prompt) => {
                    prompt.set_answer(self.0.masked_prompt(prompt.question()))
                }
                Message::RadioPrompt(prompt) => {
                    prompt.set_answer(self.0.radio_prompt(prompt.question()))
                }
                // Info and error messages have no answer of their own;
                // acknowledge delivery with `Ok(())`.
                Message::InfoMsg(prompt) => {
                    self.0.info_msg(prompt.question());
                    prompt.set_answer(Ok(()))
                }
                Message::ErrorMsg(prompt) => {
                    self.0.error_msg(prompt.question());
                    prompt.set_answer(Ok(()))
                }
                Message::BinaryPrompt(prompt) => {
                    let q = prompt.question();
                    prompt.set_answer(self.0.binary_prompt(q.data(), q.data_type()))
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{
        BinaryPrompt, Conversation, ErrorMsg, InfoMsg, MaskedPrompt, Message, Prompt, QAndA,
        RadioPrompt, Result, SecureString, SimpleConversation,
    };
    use crate::constants::ErrorCode;
    use crate::BinaryData;

    #[test]
    fn test_demux() {
        #[derive(Default)]
        struct DemuxTester {
            error_ran: bool,
            info_ran: bool,
        }

        impl SimpleConversation for DemuxTester {
            fn prompt(&mut self, request: &str) -> Result<String> {
                match request {
                    "what" => Ok("whatwhat".to_owned()),
                    "give_err" => Err(ErrorCode::PermissionDenied),
                    _ => panic!("unexpected prompt!"),
                }
            }
            fn masked_prompt(&mut self, request: &str) -> Result<SecureString> {
                assert_eq!("reveal", request);
                Ok(SecureString::from("my secrets"))
            }
            fn radio_prompt(&mut self, request: &str) -> Result<String> {
                assert_eq!("channel?", request);
                Ok("zero".to_owned())
            }
            fn error_msg(&mut self, message: &str) {
                self.error_ran = true;
                assert_eq!("whoopsie", message);
            }
            fn info_msg(&mut self, message: &str) {
                self.info_ran = true;
                assert_eq!("did you know", message);
            }
            fn binary_prompt(&mut self, data: &[u8], data_type: u8) -> Result<BinaryData> {
                assert_eq!(&[10, 9, 8], data);
                assert_eq!(66, data_type);
                Ok(BinaryData::new(vec![5, 5, 5], 5))
            }
        }

        let mut tester = DemuxTester::default();

        let what = Prompt::ask("what");
        let pass = MaskedPrompt::ask("reveal");
        let err = ErrorMsg::new("whoopsie");
        let info = InfoMsg::new("did you know");
        let has_err = Prompt::ask("give_err");

        let mut conv = tester.as_conversation();

        // Basic tests.

        conv.communicate(&[
            what.message(),
            pass.message(),
            err.message(),
            info.message(),
            has_err.message(),
        ]);

        assert_eq!("whatwhat", what.answer().unwrap());
        assert_eq!(SecureString::from("my secrets"), pass.answer().unwrap());
        assert_eq!(Ok(()), err.answer());
        assert_eq!(Ok(()), info.answer());
        assert_eq!(ErrorCode::PermissionDenied, has_err.answer().unwrap_err());
        assert!(tester.error_ran);
        assert!(tester.info_ran);

        // Test the Linux extensions separately.

        let mut conv = tester.as_conversation();

        let radio = RadioPrompt::ask("channel?");
        let bin = BinaryPrompt::ask(&[10, 9, 8], 66);
        conv.communicate(&[radio.message(), bin.message()]);

        assert_eq!("zero", radio.answer().unwrap());
        assert_eq!(BinaryData::new(vec![5, 5, 5], 5), bin.answer().unwrap());
    }

    #[test]
    fn test_mux() {
        struct MuxTester;

        impl Conversation for MuxTester {
            fn communicate(&mut self, messages: &[Message]) {
                if let [msg] = messages {
                    match *msg {
                        Message::InfoMsg(info) => {
                            assert_eq!("let me tell you", info.question());
                            info.set_answer(Ok(()))
                        }
                        Message::ErrorMsg(error) => {
                            assert_eq!("oh no", error.question());
                            error.set_answer(Ok(()))
                        }
                        Message::Prompt(prompt) => prompt.set_answer(match prompt.question() {
                            "should_err" => Err(ErrorCode::PermissionDenied),
                            "question" => Ok("answer".to_owned()),
                            other => panic!("unexpected question {other:?}"),
                        }),
                        Message::MaskedPrompt(ask) => match ask.question() {
                            "password!" => ask.set_answer(Ok("open sesame".into())),
                            // Deliberately left unanswered: the caller should
                            // see the answer the prompt was constructed with.
                            "unanswered" => {}
                            other => panic!("unexpected masked prompt {other:?}"),
                        },
                        Message::BinaryPrompt(prompt) => {
                            assert_eq!(&[1, 2, 3], prompt.question().data);
                            assert_eq!(69, prompt.question().data_type);
                            prompt.set_answer(Ok(BinaryData::new(vec![3, 2, 1], 42)))
                        }
                        Message::RadioPrompt(ask) => {
                            assert_eq!("radio?", ask.question());
                            ask.set_answer(Ok("yes".to_owned()))
                        }
                    }
                } else {
                    panic!(
                        "there should only be one message, not {len}",
                        len = messages.len()
                    )
                }
            }
        }

        let mut tester = MuxTester;

        assert_eq!("answer", tester.prompt("question").unwrap());
        assert_eq!(
            SecureString::from("open sesame"),
            tester.masked_prompt("password!").unwrap()
        );
        tester.error_msg("oh no");
        tester.info_msg("let me tell you");
        assert_eq!("yes", tester.radio_prompt("radio?").unwrap());
        assert_eq!(
            BinaryData::new(vec![3, 2, 1], 42),
            tester.binary_prompt(&[1, 2, 3], 69).unwrap(),
        );
        assert_eq!(
            ErrorCode::PermissionDenied,
            tester.prompt("should_err").unwrap_err(),
        );
        assert_eq!(
            ErrorCode::ConversationError,
            tester.masked_prompt("unanswered").unwrap_err()
        );
    }

    #[test]
    fn test_set_error() {
        // `set_error` overwrites any answer that was already set.
        let prompt = Prompt::ask("anything");
        prompt.set_answer(Ok("will be replaced".to_owned()));
        prompt.message().set_error(ErrorCode::PermissionDenied);
        assert_eq!(ErrorCode::PermissionDenied, prompt.answer().unwrap_err());

        // It also works on messages that normally get `Ok(())`.
        let info = InfoMsg::new("fyi");
        info.message().set_error(ErrorCode::BufferError);
        assert_eq!(Err(ErrorCode::BufferError), info.answer());

        let bin = BinaryPrompt::ask(&[1, 2], 3);
        bin.message().set_error(ErrorCode::PermissionDenied);
        assert_eq!(ErrorCode::PermissionDenied, bin.answer().unwrap_err());
    }
}