view src/libpam/question.rs @ 85:5e14bb093851

fix more openpam compat stuff
author Paul Fisher <paul@pfish.zone>
date Tue, 10 Jun 2025 02:42:29 -0400
parents 5aa1a010f1e8
children 05291b601f0a
line wrap: on
line source

//! Data and types dealing with PAM messages.

use crate::conv::{BinaryQAndA, ErrorMsg, InfoMsg, MaskedQAndA, Message, QAndA, RadioQAndA};
use crate::libpam::conversation::OwnedMessage;
use crate::libpam::memory;
use crate::libpam::memory::{CBinaryData, Immovable};
pub use crate::libpam::pam_ffi::{Question, Style};
use crate::ErrorCode;
use crate::Result;
use std::ffi::{c_void, CStr};
use std::{iter, ptr, slice};

/// Abstraction of a collection of questions to be sent in a PAM conversation.
///
/// The PAM C API conversation function looks like this:
///
/// ```c
/// int pam_conv(
///     int count,
///     const struct pam_message **questions,
///     struct pam_response **answers,
///     void *appdata_ptr,
/// )
/// ```
///
/// On Linux-PAM and other compatible implementations, `questions`
/// is treated as a pointer-to-pointers, like `int argc, char **argv`.
/// (In this situation, the value of `Questions.indirect` is
/// the pointer passed to `pam_conv`.)
///
/// ```text
/// ╔═ Questions ═╗  points to  ┌─ Indirect ─┐       ╔═ Question ═╗
/// ║ indirect ┄┄┄╫┄┄┄┄┄┄┄┄┄┄┄> │ base[0] ┄┄┄┼┄┄┄┄┄> ║ style      ║
/// ║ count       ║             │ base[1] ┄┄┄┼┄┄┄╮   ║ data ┄┄┄┄┄┄╫┄┄> ...
/// ╚═════════════╝             │ ...        │   ┆   ╚════════════╝
///                                              ┆
///                                              ┆    ╔═ Question ═╗
///                                              ╰┄┄> ║ style      ║
///                                                   ║ data ┄┄┄┄┄┄╫┄┄> ...
///                                                   ╚════════════╝
/// ```
///
/// On OpenPAM and other compatible implementations (like Solaris),
/// `questions` is a pointer-to-pointer-to-array.  This appears to be
/// the correct implementation as required by the XSSO specification.
///
/// ```text
/// ╔═ Questions ═╗  points to  ┌─ Indirect ─┐       ╔═ Question[] ═╗
/// ║ indirect ┄┄┄╫┄┄┄┄┄┄┄┄┄┄┄> │ base ┄┄┄┄┄┄┼┄┄┄┄┄> ║ style        ║
/// ║ count       ║             └────────────┘       ║ data ┄┄┄┄┄┄┄┄╫┄┄> ...
/// ╚═════════════╝                                  ╟──────────────╢
///                                                  ║ style        ║
///                                                  ║ data ┄┄┄┄┄┄┄┄╫┄┄> ...
///                                                  ╟──────────────╢
///                                                  ║ ...          ║
/// ```
pub struct GenericQuestions<I: IndirectTrait> {
    /// An indirection to the questions themselves, stored on the C heap.
    ///
    /// Obtained from `I::alloc(count)` and released when this is dropped.
    /// The drop code tolerates this being null.
    indirect: *mut I,
    /// The number of questions; the same count passed to `I::alloc`.
    count: usize,
}

impl<I: IndirectTrait> GenericQuestions<I> {
    /// Stores the provided questions on the C heap.
    pub fn new(messages: &[Message]) -> Result<Self> {
        let count = messages.len();
        let mut ret = Self {
            indirect: I::alloc(count),
            count,
        };
        // Even if we fail partway through this, all our memory will be freed.
        for (question, message) in iter::zip(ret.iter_mut(), messages) {
            question.fill(message)?
        }
        Ok(ret)
    }

    /// The pointer to the thing with the actual list.
    pub fn indirect(&self) -> *const *const Question {
        self.indirect.cast()
    }

    pub fn iter(&self) -> impl Iterator<Item = &Question> {
        // SAFETY: we're iterating over an amount we know.
        unsafe { (*self.indirect).iter(self.count) }
    }
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Question> {
        // SAFETY: we're iterating over an amount we know.
        unsafe { (*self.indirect).iter_mut(self.count) }
    }
}

impl<I: IndirectTrait> Drop for GenericQuestions<I> {
    /// Clears and frees every question, then frees the indirector itself.
    fn drop(&mut self) {
        let count = self.count;
        let indirect = self.indirect;
        self.indirect = ptr::null_mut();
        // SAFETY: `indirect` is either null or the pointer `I::alloc` handed
        // us for exactly `count` questions, and nothing else owns it.
        // Our own field no longer refers to it by the time it is freed.
        unsafe {
            if !indirect.is_null() {
                (*indirect).free_contents(count);
            }
            memory::free(indirect);
        }
    }
}

/// The trait that each of the `Indirect` implementations implement.
///
/// Basically a slice but with more meat. An implementor knows how its
/// questions are laid out on the C heap but does not record how many there
/// are, so every method touching them takes the count explicitly.
pub trait IndirectTrait {
    /// Reinterprets a raw `questions` pointer as a borrowed `Self`.
    ///
    /// Returns `None` when `ptr` is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a live indirector with this layout,
    /// and the returned reference must not outlive it.
    unsafe fn borrow_ptr<'a>(ptr: *const *const Question) -> Option<&'a Self>
    where
        Self: Sized,
    {
        let typed: *const Self = ptr.cast();
        typed.as_ref()
    }

    /// Allocates this indirector together with storage for `count` questions.
    fn alloc(count: usize) -> *mut Self;

    /// Iterates over the first `count` questions.
    ///
    /// # Safety
    ///
    /// `count` must be the number of questions actually stored.
    unsafe fn iter(&self, count: usize) -> impl Iterator<Item = &Question>;

    /// Mutably iterates over the first `count` questions.
    ///
    /// # Safety
    ///
    /// `count` must be the number of questions actually stored.
    unsafe fn iter_mut(&mut self, count: usize) -> impl Iterator<Item = &mut Question>;

    /// Clears and frees everything this indirector points to.
    ///
    /// # Safety
    ///
    /// `count` must be the number of questions this was allocated with.
    unsafe fn free_contents(&mut self, count: usize);
}

/// An indirect reference to messages.
///
/// This is kept separate to provide a place where we can separate
/// the pointer-to-pointer-to-list from pointer-to-list-of-pointers.
#[cfg(pam_impl = "linux-pam")]
pub type Indirect = LinuxPamIndirect;

/// An indirect reference to messages.
///
/// This is kept separate to provide a place where we can separate
/// the pointer-to-pointer-to-list from pointer-to-list-of-pointers.
#[cfg(not(pam_impl = "linux-pam"))]
pub type Indirect = XSsoIndirect;

pub type Questions = GenericQuestions<Indirect>;

/// The XSSO standard version of the indirection layer between Question and Questions.
///
/// `base` points to one contiguous, heap-allocated array of [`Question`]s.
#[repr(transparent)]
pub struct StandardIndirect {
    base: *mut Question,
    _marker: Immovable,
}

impl IndirectTrait for StandardIndirect {
    /// Allocates the indirector and a zeroed array of `count` questions.
    ///
    /// Returns null if either allocation fails; anything that was obtained
    /// is released first, so nothing leaks.
    fn alloc(count: usize) -> *mut Self {
        let questions: *mut Question = memory::calloc(count);
        let me_ptr: *mut Self = memory::calloc(1);
        // A zero-length array may legitimately come back as null.
        let questions_ok = count == 0 || !questions.is_null();
        // SAFETY: Both pointers come straight from `calloc`. We only write
        // through `me_ptr` after checking that it is non-null, and on failure
        // we free whatever we did get (freeing null is a no-op).
        unsafe {
            match me_ptr.as_mut() {
                Some(me) if questions_ok => {
                    me.base = questions;
                    me_ptr
                }
                _ => {
                    memory::free(questions);
                    memory::free(me_ptr);
                    ptr::null_mut()
                }
            }
        }
    }

    unsafe fn iter(&self, count: usize) -> impl Iterator<Item = &Question> {
        (0..count).map(|idx| &*self.base.add(idx))
    }

    unsafe fn iter_mut(&mut self, count: usize) -> impl Iterator<Item = &mut Question> {
        (0..count).map(|idx| &mut *self.base.add(idx))
    }

    unsafe fn free_contents(&mut self, count: usize) {
        // `slice::from_raw_parts_mut` requires a non-null pointer even for an
        // empty slice, and `calloc(0)` is allowed to return null.
        if !self.base.is_null() {
            let msgs = slice::from_raw_parts_mut(self.base, count);
            for msg in msgs {
                msg.clear()
            }
        }
        memory::free(self.base);
        self.base = ptr::null_mut()
    }
}

/// The Linux version of the indirection layer between Question and Questions.
///
/// The struct has no storage of its own (`base` is a zero-length array).
/// A pointer to it is really a pointer to the start of a heap-allocated
/// list of `*mut Question`, each pointing at a separately allocated question.
#[repr(transparent)]
pub struct LinuxPamIndirect {
    base: [*mut Question; 0],
    _marker: Immovable,
}

impl IndirectTrait for LinuxPamIndirect {
    /// Allocates a zeroed list of `count` question pointers, plus a zeroed
    /// [`Question`] for each of them.
    ///
    /// Returns null if any allocation fails; anything that was obtained is
    /// released first, so nothing leaks.
    fn alloc(count: usize) -> *mut Self {
        let slots: *mut *mut Question = memory::calloc(count);
        if slots.is_null() {
            return ptr::null_mut();
        }
        // SAFETY: `slots` is a fresh, non-null allocation with room for
        // `count` pointers. We work through the raw pointer rather than a
        // `&mut Self`, because `Self` is zero-sized and a reference to it does
        // not cover the trailing list of pointers.
        unsafe {
            let entries = slice::from_raw_parts_mut(slots, count);
            for entry in entries.iter_mut() {
                *entry = memory::calloc(1);
            }
            if entries.iter().any(|entry| entry.is_null()) {
                // Out of memory partway through: give back everything we got.
                for &entry in entries.iter() {
                    memory::free(entry);
                }
                memory::free(slots);
                return ptr::null_mut();
            }
        }
        slots.cast()
    }

    unsafe fn iter(&self, count: usize) -> impl Iterator<Item = &Question> {
        (0..count).map(|idx| &**self.base.as_ptr().add(idx))
    }

    unsafe fn iter_mut(&mut self, count: usize) -> impl Iterator<Item = &mut Question> {
        (0..count).map(|idx| &mut **self.base.as_mut_ptr().add(idx))
    }

    unsafe fn free_contents(&mut self, count: usize) {
        let msgs = slice::from_raw_parts_mut(self.base.as_mut_ptr(), count);
        for msg in msgs {
            if let Some(msg) = msg.as_mut() {
                msg.clear();
            }
            memory::free(*msg);
            *msg = ptr::null_mut();
        }
    }
}

impl Default for Question {
    /// Returns an empty question: default style and a null data pointer.
    fn default() -> Self {
        let data = ptr::null_mut();
        Self {
            data,
            style: Default::default(),
            _marker: Default::default(),
        }
    }
}

impl Question {
    /// Overwrites this question with a C-heap copy of `msg`.
    ///
    /// The copy is made before anything is cleared, so if copying fails
    /// the existing contents are left untouched.
    pub fn fill(&mut self, msg: &Message) -> Result<()> {
        let copied = copy_to_heap(msg)?;
        self.clear();
        let (style, data) = copied;
        self.style = style.into();
        self.data = data;
        Ok(())
    }

    /// Reads the data pointer as a NUL-terminated UTF-8 string.
    ///
    /// A null pointer reads as `""`; invalid UTF-8 yields
    /// [`ErrorCode::ConversationError`].
    ///
    /// # Safety
    ///
    /// Only call this on questions whose style carries string data.
    unsafe fn string_data(&self) -> Result<&str> {
        if self.data.is_null() {
            return Ok("");
        }
        let text = CStr::from_ptr(self.data.cast());
        match text.to_str() {
            Ok(s) => Ok(s),
            Err(_) => Err(ErrorCode::ConversationError),
        }
    }

    /// Reads the data pointer as borrowed binary data and its type byte.
    ///
    /// A null pointer reads as an empty buffer with type `0`.
    unsafe fn binary_data(&self) -> (&[u8], u8) {
        match self.data.cast::<CBinaryData>().as_ref() {
            Some(bin) => bin.into(),
            None => Default::default(),
        }
    }

    /// Zeroes out and frees the data stored here, leaving a null pointer.
    ///
    /// If the style is not recognized, the data is freed without zeroing.
    fn clear(&mut self) {
        // SAFETY: We either created this data or we got it from PAM, and the
        // style tells us how to interpret it. Afterwards `data` is null.
        unsafe {
            match Style::try_from(self.style) {
                Ok(Style::BinaryPrompt) => {
                    if let Some(bin) = self.data.cast::<CBinaryData>().as_mut() {
                        bin.zero_contents()
                    }
                }
                Ok(
                    Style::TextInfo
                    | Style::RadioType
                    | Style::ErrorMsg
                    | Style::PromptEchoOff
                    | Style::PromptEchoOn,
                ) => memory::zero_c_string(self.data.cast()),
                Err(_) => {}
            }
            memory::free(self.data);
            self.data = ptr::null_mut();
        }
    }
}

impl<'a> TryFrom<&'a Question> for OwnedMessage<'a> {
    type Error = ErrorCode;

    /// Borrows the contents of a PAM question as an [`OwnedMessage`].
    ///
    /// # Errors
    ///
    /// Returns [`ErrorCode::ConversationError`] if the question's style is
    /// not a known [`Style`], or if a text payload is not valid UTF-8.
    fn try_from(question: &'a Question) -> Result<Self> {
        let style: Style = question
            .style
            .try_into()
            .map_err(|_| ErrorCode::ConversationError)?;
        // SAFETY: In all cases below, we're matching the accessor to the
        // question's declared style. String styles read `data` as a C string
        // and the binary style reads it as `CBinaryData`. Both accessors
        // treat a null `data` pointer as empty.
        let prompt = unsafe {
            match style {
                Style::PromptEchoOff => {
                    Self::MaskedPrompt(MaskedQAndA::new(question.string_data()?))
                }
                Style::PromptEchoOn => Self::Prompt(QAndA::new(question.string_data()?)),
                Style::ErrorMsg => Self::Error(ErrorMsg::new(question.string_data()?)),
                Style::TextInfo => Self::Info(InfoMsg::new(question.string_data()?)),
                Style::RadioType => Self::RadioPrompt(RadioQAndA::new(question.string_data()?)),
                Style::BinaryPrompt => Self::BinaryPrompt(BinaryQAndA::new(question.binary_data())),
            }
        };
        Ok(prompt)
    }
}

/// Copies the contents of this message to the C heap.
fn copy_to_heap(msg: &Message) -> Result<(Style, *mut c_void)> {
    let alloc = |style, text| Ok((style, memory::malloc_str(text)?.cast()));
    match *msg {
        Message::MaskedPrompt(p) => alloc(Style::PromptEchoOff, p.question()),
        Message::Prompt(p) => alloc(Style::PromptEchoOn, p.question()),
        Message::RadioPrompt(p) => alloc(Style::RadioType, p.question()),
        Message::Error(p) => alloc(Style::ErrorMsg, p.question()),
        Message::Info(p) => alloc(Style::TextInfo, p.question()),
        Message::BinaryPrompt(p) => Ok((
            Style::BinaryPrompt,
            CBinaryData::alloc(p.question())?.cast(),
        )),
    }
}

#[cfg(test)]
mod tests {

    use super::{
        BinaryQAndA, ErrorMsg, GenericQuestions, IndirectTrait, InfoMsg, LinuxPamIndirect,
        MaskedQAndA, OwnedMessage, QAndA, RadioQAndA, Result, StandardIndirect,
    };

    // Asserts that `$id` is the `$variant` enum variant and that its
    // question equals `$q`; otherwise panics, printing the actual value.
    // Inside the `if let`, `$id` is rebound to the variant's payload.
    macro_rules! assert_matches {
        ($id:ident => $variant:path, $q:expr) => {
            if let $variant($id) = $id {
                assert_eq!($q, $id.question());
            } else {
                panic!("mismatched enum variant {x:?}", x = $id);
            }
        };
    }

    // Generates a round-trip test for one `IndirectTrait` implementation.
    // It builds questions on the C heap and re-borrows them through the raw
    // `indirect()` pointer. It then converts them back to `OwnedMessage`s and
    // checks that each style and payload survived, in order.
    macro_rules! tests { ($fn_name:ident<$typ:ident>) => {
        #[test]
        fn $fn_name() {
            let interrogation = GenericQuestions::<$typ>::new(&[
                MaskedQAndA::new("hocus pocus").message(),
                BinaryQAndA::new((&[5, 4, 3, 2, 1], 66)).message(),
                QAndA::new("what").message(),
                QAndA::new("who").message(),
                InfoMsg::new("hey").message(),
                ErrorMsg::new("gasp").message(),
                RadioQAndA::new("you must choose").message(),
            ])
            .unwrap();
            let indirect = interrogation.indirect();

            let remade = unsafe { $typ::borrow_ptr(indirect) }.unwrap();
            let messages: Vec<OwnedMessage> = unsafe { remade.iter(interrogation.count) }
                .map(TryInto::try_into)
                .collect::<Result<_>>()
                .unwrap();
            let [masked, bin, what, who, hey, gasp, choose] = messages.try_into().unwrap();
            assert_matches!(masked => OwnedMessage::MaskedPrompt, "hocus pocus");
            assert_matches!(bin => OwnedMessage::BinaryPrompt, (&[5, 4, 3, 2, 1][..], 66));
            assert_matches!(what => OwnedMessage::Prompt, "what");
            assert_matches!(who => OwnedMessage::Prompt, "who");
            assert_matches!(hey => OwnedMessage::Info, "hey");
            assert_matches!(gasp => OwnedMessage::Error, "gasp");
            assert_matches!(choose => OwnedMessage::RadioPrompt, "you must choose");
        }
    }}

    // One test per layout: the XSSO/OpenPAM contiguous array and the
    // Linux-PAM list of pointers. Both are tested regardless of build target.
    tests!(test_xsso<StandardIndirect>);
    tests!(test_linux<LinuxPamIndirect>);
}