view src/libpam/conversation.rs @ 78:002adfb98c5c

Rename files, reorder structs, remove annoying BorrowedBinaryData type. This is basically a cleanup change. Also it adds tests. - Renames the files with Questions and Answers to question and answer. - Reorders the structs in those files to put the important ones first. - Removes the BorrowedBinaryData type. It was a bad idea all along. Instead, we just use (&[u8], u8). - Adds some tests because I just can't help myself.
author Paul Fisher <paul@pfish.zone>
date Sun, 08 Jun 2025 03:48:40 -0400
parents 351bdc13005e
children 5aa1a010f1e8
line wrap: on
line source

use crate::conv::{
    BinaryQAndA, Conversation, ErrorMsg, InfoMsg, MaskedQAndA, Message, QAndA, RadioQAndA,
};
use crate::libpam::answer::{Answer, Answers, BinaryAnswer, TextAnswer};
use crate::libpam::memory::Immovable;
use crate::libpam::question::{Indirect, Questions};
use crate::ErrorCode;
use crate::Result;
use std::ffi::c_int;
use std::iter;
use std::marker::PhantomData;

/// An opaque structure that is passed through PAM in a conversation.
///
/// Nothing in this module constructs a value of this type; it only ever
/// appears behind a `*mut AppData`. In practice that pointer is a
/// type-erased pointer to the Rust [`Conversation`] being wrapped: it is
/// produced by casting `&mut C` in `LibPamConversation::wrap` and cast back
/// to `C` inside `LibPamConversation::wrapper_callback`.
#[repr(C)]
pub struct AppData {
    // Zero-sized placeholder; there is no Rust-visible data behind the pointer.
    _data: (),
    // Zero-sized marker; see [`Immovable`]. NOTE(review): presumably this
    // keeps the type from being moved/treated as ordinary data — confirm.
    _marker: Immovable,
}

/// The callback that PAM uses to get information in a conversation.
///
/// - `num_msg` is the number of messages in the `pam_message` array.
/// - `messages` is a pointer to the messages being sent to the user.
///   For details about its structure, see the documentation of
///   [`Questions`] (which builds this buffer) and [`Indirect`].
/// - `responses` is an out-parameter: on success, the callback stores
///   through it a pointer to an array of [`Answer`]s, which PAM sets in
///   response to a module's request.
///   This is an array of structs, not an array of pointers to a struct.
///   There should always be exactly as many `responses` as `num_msg`.
/// - `appdata` is the `appdata` field of the [`LibPamConversation`] we were passed.
///
/// The return value is a PAM status code. In this module it is produced with
/// `ErrorCode::result_to_c` and interpreted with `ErrorCode::result_from`.
pub type ConversationCallback = unsafe extern "C" fn(
    num_msg: c_int,
    messages: *const Indirect,
    responses: *mut *mut Answer,
    appdata: *mut AppData,
) -> c_int;

/// The type used by PAM to call back into a conversation.
///
/// NOTE(review): this is `#[repr(C)]` and presumably mirrors libpam's
/// `struct pam_conv` (a callback function followed by an `appdata_ptr`),
/// so the order of `callback` and `appdata` must not change — confirm
/// against the C header.
///
/// The `'a` lifetime is carried only by the zero-sized `life` marker;
/// presumably it stands for the borrow of the conversation that `appdata`
/// points at.
#[repr(C)]
pub struct LibPamConversation<'a> {
    /// The function that is called to get information from the user.
    callback: ConversationCallback,
    /// The pointer that will be passed as the last parameter
    /// to the conversation callback.
    appdata: *mut AppData,
    /// Zero-sized marker that ties this struct to the `'a` lifetime.
    life: PhantomData<&'a mut ()>,
    /// Zero-sized marker; see [`Immovable`].
    _marker: Immovable,
}

impl<'a> LibPamConversation<'a> {
    /// Wraps a Rust [`Conversation`] so that it can be handed to PAM.
    ///
    /// `appdata` stores a type-erased pointer to `conv`, and `callback` is
    /// [`Self::wrapper_callback`] monomorphized for `C`, which knows how to
    /// cast that pointer back. The result borrows `conv` mutably for `'a`,
    /// so the conversation cannot be dropped or used elsewhere while the
    /// returned value is still in use.
    fn wrap<C: Conversation>(conv: &'a mut C) -> Self {
        Self {
            callback: Self::wrapper_callback::<C>,
            appdata: (conv as *mut C).cast(),
            life: PhantomData,
            _marker: Immovable(PhantomData),
        }
    }

    /// The `extern "C"` trampoline that PAM calls for a conversation built
    /// by [`Self::wrap`].
    ///
    /// It converts PAM's questions into [`OwnedMessage`]s, asks the wrapped
    /// `C` to answer them, and writes a freshly built answer array back
    /// through `answers`. Any failure (a null pointer, a negative `count`,
    /// a question that can't be converted, or failing to build the answers)
    /// is reported to PAM as [`ErrorCode::ConversationError`].
    ///
    /// NOTE(review): a panic inside `C::communicate` is not caught here and
    /// would try to unwind across this `extern "C"` boundary.
    ///
    /// # Safety
    ///
    /// `me` must be the `appdata` pointer produced by [`Self::wrap`] for the
    /// same `C`, and that conversation must still be alive. `questions` must
    /// point at `count` questions laid out as [`Indirect`] describes, and
    /// `answers` must be valid to write a pointer into.
    unsafe extern "C" fn wrapper_callback<C: Conversation>(
        count: c_int,
        questions: *const Indirect,
        answers: *mut *mut Answer,
        me: *mut AppData,
    ) -> c_int {
        let internal = || {
            // Collect all our pointers
            let conv = me
                .cast::<C>()
                .as_mut()
                .ok_or(ErrorCode::ConversationError)?;
            let indirect = questions.as_ref().ok_or(ErrorCode::ConversationError)?;
            let answers_ptr = answers.as_mut().ok_or(ErrorCode::ConversationError)?;
            // `count` comes from C. A negative value would otherwise wrap to
            // an enormous `usize` and have us read far past PAM's array.
            let count = usize::try_from(count).map_err(|_| ErrorCode::ConversationError)?;

            // Build our owned list of Q&As from the questions we've been asked
            let messages: Vec<OwnedMessage> = indirect
                .iter(count)
                .map(TryInto::try_into)
                .collect::<Result<_>>()
                .map_err(|_| ErrorCode::ConversationError)?;
            // Borrow all those Q&As and ask them
            let borrowed: Vec<Message> = messages.iter().map(Into::into).collect();
            conv.communicate(&borrowed);

            // Send our answers back
            let owned = Answers::build(messages).map_err(|_| ErrorCode::ConversationError)?;
            *answers_ptr = owned.into_ptr();
            Ok(())
        };
        ErrorCode::result_to_c(internal())
    }
}

impl Conversation for LibPamConversation<'_> {
    /// Sends `messages` through this PAM conversation and records the
    /// answers on each message.
    ///
    /// The messages are converted to [`Questions`] and passed to the C
    /// callback. Each returned [`Answer`] is then copied into its matching
    /// message. If any step fails, every message is given the resulting
    /// error instead. Failures include:
    /// - the batch is too large for a C `int`;
    /// - the questions can't be built;
    /// - the callback reports an error;
    /// - the callback reports success but returns no responses.
    fn communicate(&mut self, messages: &[Message]) {
        let internal = || {
            // PAM takes the count as a C `int`; refuse an oversized batch
            // rather than silently truncating it.
            let count =
                c_int::try_from(messages.len()).map_err(|_| ErrorCode::ConversationError)?;
            let questions = Questions::new(messages)?;
            let mut response_pointer = std::ptr::null_mut();
            // SAFETY: We're calling into PAM with valid everything: `count`
            // is the number of messages the questions were built from,
            // `questions` outlives the call, and `response_pointer` is a
            // valid place for the callback to write into.
            let result = unsafe {
                (self.callback)(
                    count,
                    questions.indirect(),
                    &mut response_pointer,
                    self.appdata,
                )
            };
            ErrorCode::result_from(result)?;
            // A callback that claims success but hands back no responses is
            // misbehaving; never try to read answers out of a null pointer.
            if response_pointer.is_null() {
                return Err(ErrorCode::ConversationError);
            }
            // SAFETY: This is a non-null pointer we just got back from PAM.
            // We have to trust that the responses from PAM match up
            // with the questions we sent.
            unsafe {
                let mut owned_responses = Answers::from_c_heap(response_pointer, messages.len());
                for (msg, response) in iter::zip(messages, owned_responses.iter_mut()) {
                    convert(msg, response);
                }
            };
            Ok(())
        };
        if let Err(e) = internal() {
            messages.iter().for_each(|m| m.set_error(e))
        }
    }
}

/// Like [`Message`], but this time we own the contents.
///
/// Each variant holds its Q&A struct by value, whereas [`Message`] holds a
/// reference to one. A `&OwnedMessage` converts into a [`Message`] via
/// [`From`]. The Q&A structs are still parameterized by `'a`, so they may
/// presumably keep borrowing data (such as prompt text) for that lifetime.
///
/// In this module, these are built from PAM's incoming questions (via
/// `TryInto`), answered by a [`Conversation`], and then handed to
/// `Answers::build`.
#[derive(Debug)]
pub enum OwnedMessage<'a> {
    /// A text prompt; presumably one whose input should be masked.
    MaskedPrompt(MaskedQAndA<'a>),
    /// A text prompt.
    Prompt(QAndA<'a>),
    /// A radio-style prompt; its answer is read as text (see `convert`).
    RadioPrompt(RadioQAndA<'a>),
    /// A prompt answered with binary data.
    BinaryPrompt(BinaryQAndA<'a>),
    /// An informational message; it is acknowledged with `Ok(())`.
    Info(InfoMsg<'a>),
    /// An error message; it is acknowledged with `Ok(())`.
    Error(ErrorMsg<'a>),
}

impl<'a> From<&'a OwnedMessage<'a>> for Message<'a> {
    fn from(src: &'a OwnedMessage) -> Self {
        match src {
            OwnedMessage::MaskedPrompt(m) => Message::MaskedPrompt(m),
            OwnedMessage::Prompt(m) => Message::Prompt(m),
            OwnedMessage::RadioPrompt(m) => Message::RadioPrompt(m),
            OwnedMessage::BinaryPrompt(m) => Message::BinaryPrompt(m),
            OwnedMessage::Info(m) => Message::Info(m),
            OwnedMessage::Error(m) => Message::Error(m),
        }
    }
}

/// Fills in the answer of the Message with the given response.
///
/// # Safety
///
/// You are responsible for ensuring that the src-dst pair matches.
unsafe fn convert(msg: &Message, resp: &mut Answer) {
    macro_rules! fill_text {
        ($dst:ident, $src:ident) => {{
            let text_resp = unsafe { TextAnswer::upcast($src) };
            $dst.set_answer(text_resp.contents().map(Into::into));
        }};
    }
    match *msg {
        Message::MaskedPrompt(qa) => fill_text!(qa, resp),
        Message::Prompt(qa) => fill_text!(qa, resp),
        Message::RadioPrompt(qa) => fill_text!(qa, resp),
        Message::Error(m) => m.set_answer(Ok(())),
        Message::Info(m) => m.set_answer(Ok(())),
        Message::BinaryPrompt(qa) => {
            let bin_resp = unsafe { BinaryAnswer::upcast(resp) };
            qa.set_answer(Ok(bin_resp.data().into()));
            bin_resp.zero_contents()
        }
    }
}