diff src/libpam/question.rs @ 78:002adfb98c5c

Rename files, reorder structs, remove annoying BorrowedBinaryData type. This is basically a cleanup change. Also it adds tests. - Renames the files with Questions and Answers to question and answer. - Reorders the structs in those files to put the important ones first. - Removes the BorrowedBinaryData type. It was a bad idea all along. Instead, we just use (&[u8], u8). - Adds some tests because I just can't help myself.
author Paul Fisher <paul@pfish.zone>
date Sun, 08 Jun 2025 03:48:40 -0400
parents src/libpam/message.rs@351bdc13005e
children 2128123b9406
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/libpam/question.rs	Sun Jun 08 03:48:40 2025 -0400
@@ -0,0 +1,374 @@
+//! Data and types dealing with PAM messages.
+
+use crate::constants::InvalidEnum;
+use crate::conv::{BinaryQAndA, ErrorMsg, InfoMsg, MaskedQAndA, Message, QAndA, RadioQAndA};
+use crate::libpam::conversation::OwnedMessage;
+use crate::libpam::memory;
+use crate::libpam::memory::{CBinaryData, Immovable};
+use crate::ErrorCode;
+use crate::Result;
+use num_derive::FromPrimitive;
+use num_traits::FromPrimitive;
+use std::ffi::{c_int, c_void, CStr};
+use std::result::Result as StdResult;
+use std::{iter, ptr, slice};
+
+/// Abstraction of a collection of questions to be sent in a PAM conversation.
+///
+/// The PAM C API conversation function looks like this:
+///
+/// ```c
+/// int pam_conv(
+///     int count,
+///     const struct pam_message **questions,
+///     struct pam_response **answers,
+///     void *appdata_ptr,
+/// )
+/// ```
+///
+/// On Linux-PAM and other compatible implementations, `questions`
+/// is treated as a pointer-to-pointers, like `int argc, char **argv`.
+/// (In this situation, the value of `Questions.indirect` is
+/// the pointer passed to `pam_conv`.)
+///
+/// ```text
+/// ╔═ Questions ═╗  points to  ┌─ Indirect ─┐       ╔═ Question ═╗
+/// ║ indirect ┄┄┄╫┄┄┄┄┄┄┄┄┄┄┄> │ base[0] ┄┄┄┼┄┄┄┄┄> ║ style      ║
+/// ║ count       ║             │ base[1] ┄┄┄┼┄┄┄╮   ║ data ┄┄┄┄┄┄╫┄┄> ...
+/// ╚═════════════╝             │ ...        │   ┆   ╚════════════╝
+///                                              ┆
+///                                              ┆    ╔═ Question ═╗
+///                                              ╰┄┄> ║ style      ║
+///                                                   ║ data ┄┄┄┄┄┄╫┄┄> ...
+///                                                   ╚════════════╝
+/// ```
+///
+/// On OpenPAM and other compatible implementations (like Solaris),
+/// `messages` is a pointer-to-pointer-to-array.  This appears to be
+/// the correct implementation as required by the XSSO specification.
+///
+/// ```text
+/// ╔═ Questions ═╗  points to  ┌─ Indirect ─┐       ╔═ Question[] ═╗
+/// ║ indirect ┄┄┄╫┄┄┄┄┄┄┄┄┄┄┄> │ base ┄┄┄┄┄┄┼┄┄┄┄┄> ║ style        ║
+/// ║ count       ║             └────────────┘       ║ data ┄┄┄┄┄┄┄┄╫┄┄> ...
+/// ╚═════════════╝                                  ╟──────────────╢
+///                                                  ║ style        ║
+///                                                  ║ data ┄┄┄┄┄┄┄┄╫┄┄> ...
+///                                                  ╟──────────────╢
+///                                                  ║ ...          ║
+/// ```
+///
+/// ***THIS LIBRARY CURRENTLY SUPPORTS ONLY LINUX-PAM.***
+pub struct Questions {
+    /// An indirection to the questions themselves, stored on the C heap.
+    indirect: *mut Indirect,
+    /// The number of questions.
+    count: usize,
+}
+
+impl Questions {
+    /// Stores the provided questions on the C heap.
+    pub fn new(messages: &[Message]) -> Result<Self> {
+        let count = messages.len();
+        let mut ret = Self {
+            indirect: Indirect::alloc(count),
+            count,
+        };
+        // Even if we fail partway through this, all our memory will be freed.
+        for (question, message) in iter::zip(ret.iter_mut(), messages) {
+            question.fill(message)?
+        }
+        Ok(ret)
+    }
+
+    /// The pointer to the thing with the actual list.
+    pub fn indirect(&self) -> *const Indirect {
+        self.indirect
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = &Question> {
+        // SAFETY: we're iterating over an amount we know.
+        unsafe { (*self.indirect).iter(self.count) }
+    }
+    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Question> {
+        // SAFETY: we're iterating over an amount we know.
+        unsafe { (*self.indirect).iter_mut(self.count) }
+    }
+}
+
+impl Drop for Questions {
+    fn drop(&mut self) {
+        // SAFETY: We are valid and have a valid pointer.
+        // Once we're done, everything will be safe.
+        unsafe {
+            if let Some(indirect) = self.indirect.as_mut() {
+                indirect.free(self.count)
+            }
+            memory::free(self.indirect);
+            self.indirect = ptr::null_mut();
+        }
+    }
+}
+
+/// An indirect reference to messages.
+///
+/// This is kept separate to provide a place where we can separate
+/// the pointer-to-pointer-to-list from pointer-to-list-of-pointers.
+#[repr(transparent)]
+pub struct Indirect {
+    base: [*mut Question; 0],
+    _marker: Immovable,
+}
+
+impl Indirect {
+    /// Allocates memory for this indirector and all its members.
+    fn alloc(count: usize) -> *mut Self {
+        // SAFETY: We're only allocating, and when we're done,
+        // everything will be in a known-good state.
+        let me_ptr: *mut Indirect = memory::calloc::<Question>(count).cast();
+        unsafe {
+            let me = &mut *me_ptr;
+            let ptr_list = slice::from_raw_parts_mut(me.base.as_mut_ptr(), count);
+            for entry in ptr_list {
+                *entry = memory::calloc(1);
+            }
+            me
+        }
+    }
+
+    /// Returns an iterator yielding the given number of messages.
+    ///
+    /// # Safety
+    ///
+    /// You have to provide the right count.
+    pub unsafe fn iter(&self, count: usize) -> impl Iterator<Item = &Question> {
+        (0..count).map(|idx| &**self.base.as_ptr().add(idx))
+    }
+
+    /// Returns a mutable iterator yielding the given number of messages.
+    ///
+    /// # Safety
+    ///
+    /// You have to provide the right count.
+    pub unsafe fn iter_mut(&mut self, count: usize) -> impl Iterator<Item = &mut Question> {
+        (0..count).map(|idx| &mut **self.base.as_mut_ptr().add(idx))
+    }
+
+    /// Frees everything this points to.
+    ///
+    /// # Safety
+    ///
+    /// You have to pass the right size.
+    unsafe fn free(&mut self, count: usize) {
+        let msgs = slice::from_raw_parts_mut(self.base.as_mut_ptr(), count);
+        for msg in msgs {
+            if let Some(msg) = msg.as_mut() {
+                msg.clear();
+            }
+            memory::free(*msg);
+            *msg = ptr::null_mut();
+        }
+    }
+}
+
+/// The C enum values for messages shown to the user.
+#[derive(Debug, PartialEq, FromPrimitive)]
+pub enum Style {
+    /// Requests information from the user; will be masked when typing.
+    PromptEchoOff = 1,
+    /// Requests information from the user; will not be masked.
+    PromptEchoOn = 2,
+    /// An error message.
+    ErrorMsg = 3,
+    /// An informational message.
+    TextInfo = 4,
+    /// Yes/No/Maybe conditionals. A Linux-PAM extension.
+    RadioType = 5,
+    /// For server–client non-human interaction.
+    ///
+    /// NOT part of the X/Open PAM specification.
+    /// A Linux-PAM extension.
+    BinaryPrompt = 7,
+}
+
+impl TryFrom<c_int> for Style {
+    type Error = InvalidEnum<Self>;
+    fn try_from(value: c_int) -> StdResult<Self, Self::Error> {
+        Self::from_i32(value).ok_or(value.into())
+    }
+}
+
+impl From<Style> for c_int {
+    fn from(val: Style) -> Self {
+        val as Self
+    }
+}
+
+/// A question sent by PAM or a module to an application.
+///
+/// PAM refers to this as a "message", but we call it a question
+/// to avoid confusion with [`Message`].
+///
+/// This question, and its internal data, is owned by its creator
+/// (either the module or PAM itself).
+#[repr(C)]
+pub struct Question {
+    /// The style of message to request.
+    style: c_int,
+    /// A description of the data requested.
+    ///
+    /// For most requests, this will be an owned [`CStr`], but for requests
+    /// with [`Style::BinaryPrompt`], this will be [`CBinaryData`]
+    /// (a Linux-PAM extension).
+    data: *mut c_void,
+    _marker: Immovable,
+}
+
+impl Question {
+    /// Replaces the contents of this question with the question
+    /// from the message.
+    pub fn fill(&mut self, msg: &Message) -> Result<()> {
+        let (style, data) = copy_to_heap(msg)?;
+        self.clear();
+        self.style = style as c_int;
+        self.data = data;
+        Ok(())
+    }
+
+    /// Gets this message's data pointer as a string.
+    ///
+    /// # Safety
+    ///
+    /// It's up to you to pass this only on types with a string value.
+    unsafe fn string_data(&self) -> Result<&str> {
+        if self.data.is_null() {
+            Ok("")
+        } else {
+            CStr::from_ptr(self.data.cast())
+                .to_str()
+                .map_err(|_| ErrorCode::ConversationError)
+        }
+    }
+
+    /// Gets this message's data pointer as borrowed binary data.
+    unsafe fn binary_data(&self) -> (&[u8], u8) {
+        self.data
+            .cast::<CBinaryData>()
+            .as_ref()
+            .map(Into::into)
+            .unwrap_or_default()
+    }
+
+    /// Zeroes out the data stored here.
+    fn clear(&mut self) {
+        // SAFETY: We either created this data or we got it from PAM.
+        // After this function is done, it will be zeroed out.
+        unsafe {
+            if let Ok(style) = Style::try_from(self.style) {
+                match style {
+                    Style::BinaryPrompt => {
+                        if let Some(d) = self.data.cast::<CBinaryData>().as_mut() {
+                            d.zero_contents()
+                        }
+                    }
+                    Style::TextInfo
+                    | Style::RadioType
+                    | Style::ErrorMsg
+                    | Style::PromptEchoOff
+                    | Style::PromptEchoOn => memory::zero_c_string(self.data.cast()),
+                }
+            };
+            memory::free(self.data);
+            self.data = ptr::null_mut();
+        }
+    }
+}
+
+impl<'a> TryFrom<&'a Question> for OwnedMessage<'a> {
+    type Error = ErrorCode;
+    fn try_from(question: &'a Question) -> Result<Self> {
+        let style: Style = question
+            .style
+            .try_into()
+            .map_err(|_| ErrorCode::ConversationError)?;
+        // SAFETY: In all cases below, we're matching the
+        let prompt = unsafe {
+            match style {
+                Style::PromptEchoOff => {
+                    Self::MaskedPrompt(MaskedQAndA::new(question.string_data()?))
+                }
+                Style::PromptEchoOn => Self::Prompt(QAndA::new(question.string_data()?)),
+                Style::ErrorMsg => Self::Error(ErrorMsg::new(question.string_data()?)),
+                Style::TextInfo => Self::Info(InfoMsg::new(question.string_data()?)),
+                Style::RadioType => Self::RadioPrompt(RadioQAndA::new(question.string_data()?)),
+                Style::BinaryPrompt => Self::BinaryPrompt(BinaryQAndA::new(question.binary_data())),
+            }
+        };
+        Ok(prompt)
+    }
+}
+
+/// Copies the contents of this message to the C heap.
+fn copy_to_heap(msg: &Message) -> Result<(Style, *mut c_void)> {
+    let alloc = |style, text| Ok((style, memory::malloc_str(text)?.cast()));
+    match *msg {
+        Message::MaskedPrompt(p) => alloc(Style::PromptEchoOff, p.question()),
+        Message::Prompt(p) => alloc(Style::PromptEchoOn, p.question()),
+        Message::RadioPrompt(p) => alloc(Style::RadioType, p.question()),
+        Message::Error(p) => alloc(Style::ErrorMsg, p.question()),
+        Message::Info(p) => alloc(Style::TextInfo, p.question()),
+        Message::BinaryPrompt(p) => {
+            let q = p.question();
+            Ok((
+                Style::BinaryPrompt,
+                CBinaryData::alloc(q)?.cast(),
+            ))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::{MaskedQAndA, Questions, Result};
+    use crate::conv::{BinaryQAndA, ErrorMsg, InfoMsg, QAndA, RadioQAndA};
+    use crate::libpam::conversation::OwnedMessage;
+
+    #[test]
+    fn test_round_trip() {
+        let interrogation = Questions::new(&[
+            MaskedQAndA::new("hocus pocus").message(),
+            BinaryQAndA::new((&[5, 4, 3, 2, 1], 66)).message(),
+            QAndA::new("what").message(),
+            QAndA::new("who").message(),
+            InfoMsg::new("hey").message(),
+            ErrorMsg::new("gasp").message(),
+            RadioQAndA::new("you must choose").message(),
+        ])
+        .unwrap();
+        let indirect = interrogation.indirect();
+
+        let remade = unsafe { indirect.as_ref() }.unwrap();
+        let messages: Vec<OwnedMessage> = unsafe { remade.iter(interrogation.count) }
+            .map(TryInto::try_into)
+            .collect::<Result<_>>()
+            .unwrap();
+        let [masked, bin, what, who, hey, gasp, choose] = messages.try_into().unwrap();
+        macro_rules! assert_matches {
+            ($id:ident => $variant:path, $q:expr) => {
+                if let $variant($id) = $id {
+                    assert_eq!($q, $id.question());
+                } else {
+                    panic!("mismatched enum variant {x:?}", x = $id);
+                }
+            };
+        }
+        assert_matches!(masked => OwnedMessage::MaskedPrompt, "hocus pocus");
+        assert_matches!(bin => OwnedMessage::BinaryPrompt, (&[5, 4, 3, 2, 1][..], 66));
+        assert_matches!(what => OwnedMessage::Prompt, "what");
+        assert_matches!(who => OwnedMessage::Prompt, "who");
+        assert_matches!(hey => OwnedMessage::Info, "hey");
+        assert_matches!(gasp => OwnedMessage::Error, "gasp");
+        assert_matches!(choose => OwnedMessage::RadioPrompt, "you must choose");
+    }
+}