view src/libpam/conversation.rs @ 75:c30811b4afae

rename pam_ffi submodule to libpam.
author Paul Fisher <paul@pfish.zone>
date Fri, 06 Jun 2025 22:35:08 -0400
parents src/pam_ffi/conversation.rs@c7c596e6388f
children 351bdc13005e
line wrap: on
line source

use crate::constants::Result;
use crate::conv::{Conversation, Message, Response};
use crate::libpam::memory::Immovable;
use crate::libpam::message::{MessageIndirector, OwnedMessages};
use crate::libpam::response::{OwnedResponses, RawBinaryResponse, RawResponse, RawTextResponse};
use crate::ErrorCode;
use crate::ErrorCode::ConversationError;
use std::ffi::c_int;
use std::iter;
use std::marker::PhantomData;
use std::result::Result as StdResult;

/// An opaque structure that is passed through PAM in a conversation.
///
/// Nothing in this module ever creates an `AppData` value; it exists so that
/// `*mut AppData` is a distinct pointer type rather than a bare `void *`.
/// When a conversation is built by `LibPamConversation::wrap`, the pointer
/// stored as `appdata` is really a `*mut C` for some `C: Conversation`,
/// cast to this type and cast back inside the callback.
#[repr(C)]
pub struct AppData {
    // Zero-sized placeholder; code only ever handles `*mut AppData`.
    _data: (),
    // Marker from `libpam::memory`; presumably opts this type out of
    // being moved/shared (e.g. `!Unpin`) — confirm against its definition.
    _marker: Immovable,
}

/// The callback that PAM uses to get information in a conversation.
///
/// This is the C-ABI signature of the callback stored in
/// [`LibPamConversation`]. The parameter names are documentation only.
///
/// - `num_msg` is the number of messages in the `pam_message` array.
/// - `messages` is a pointer to the messages being sent to the user.
///   For details about its structure, see the documentation of
///   [`OwnedMessages`].
/// - `responses` is an out-parameter. On success, the callback stores into
///   `*responses` a pointer to a newly allocated array of [`RawResponse`]s,
///   which the conversation fills in response to a module's request.
///   This is an array of structs, not an array of pointers to a struct.
///   There should always be exactly as many `responses` as `num_msg`.
/// - `appdata` is the `appdata` field of the [`LibPamConversation`] we were passed.
///
/// The return value is a PAM status code. In this module it is produced with
/// `ErrorCode::result_to_c` and checked with `ErrorCode::result_from`.
pub type ConversationCallback = unsafe extern "C" fn(
    num_msg: c_int,
    messages: *const MessageIndirector,
    responses: *mut *mut RawResponse,
    appdata: *mut AppData,
) -> c_int;

/// The type used by PAM to call back into a conversation.
///
/// The `#[repr(C)]` layout (a callback function pointer followed by an
/// opaque data pointer) is meant to mirror C's `struct pam_conv`, so the
/// order of the first two fields matters. The two trailing fields are
/// zero-sized markers and do not change the C layout.
#[repr(C)]
pub struct LibPamConversation<'a> {
    /// The function that is called to get information from the user.
    callback: ConversationCallback,
    /// The pointer that will be passed as the last parameter
    /// to the conversation callback.
    appdata: *mut AppData,
    /// Marks this value as holding a mutable borrow for `'a`
    /// (presumably of whatever `appdata` points to).
    life: PhantomData<&'a mut ()>,
    /// Zero-sized marker from `libpam::memory`; presumably keeps this value
    /// from being moved once handed to PAM — confirm against its definition.
    _marker: Immovable,
}

impl<'a> LibPamConversation<'a> {
    /// Wraps a Rust [`Conversation`] so it can be handed to PAM.
    ///
    /// The result stores a raw pointer to `conv` as its `appdata` and uses
    /// [`Self::wrapper_callback`] (instantiated for `C`) as its callback.
    /// The lifetime `'a` ties the wrapper to the borrow of `conv`, so the
    /// wrapper cannot outlive the conversation its raw pointer refers to.
    fn wrap<C: Conversation>(conv: &'a mut C) -> Self {
        Self {
            callback: Self::wrapper_callback::<C>,
            appdata: (conv as *mut C).cast(),
            life: PhantomData,
            _marker: Immovable(PhantomData),
        }
    }

    /// C-ABI trampoline that PAM calls; it forwards to `C::communicate`.
    ///
    /// It converts the C message array into [`Message`]s and asks the
    /// wrapped conversation for answers. It then stores a newly built
    /// response array into `*responses`.
    ///
    /// Any of the following is reported as a conversation error:
    /// - a null pointer argument,
    /// - a negative `count`,
    /// - a message that can't be converted,
    /// - an error from the conversation,
    /// - a response count that doesn't match the message count.
    ///
    /// # Safety
    ///
    /// The caller must ensure that:
    /// - `me` is the `appdata` produced by [`Self::wrap`] for this same `C`,
    ///   and that conversation is still alive;
    /// - `messages` points to `count` messages laid out as
    ///   [`MessageIndirector`] expects;
    /// - `responses` is valid for writes.
    unsafe extern "C" fn wrapper_callback<C: Conversation>(
        count: c_int,
        messages: *const MessageIndirector,
        responses: *mut *mut RawResponse,
        me: *mut AppData,
    ) -> c_int {
        let call = || -> Result<()> {
            let conv = me
                .cast::<C>()
                .as_mut()
                .ok_or(ErrorCode::ConversationError)?;
            let indir = messages.as_ref().ok_or(ErrorCode::ConversationError)?;
            let response_ptr = responses.as_mut().ok_or(ErrorCode::ConversationError)?;
            // A negative count would wrap to an enormous `usize` with `as`
            // and make us read far past the end of the message array.
            let count = usize::try_from(count).map_err(|_| ErrorCode::ConversationError)?;
            let messages: Vec<Message> = indir
                .iter(count)
                .map(Message::try_from)
                .collect::<StdResult<_, _>>()
                .map_err(|_| ErrorCode::ConversationError)?;
            let responses = conv.communicate(&messages)?;
            // The caller will read exactly one response per message, so a
            // shorter (or longer) array would be misread.
            if responses.len() != messages.len() {
                return Err(ErrorCode::ConversationError);
            }
            let owned =
                OwnedResponses::build(&responses).map_err(|_| ErrorCode::ConversationError)?;
            *response_ptr = owned.into_ptr();
            Ok(())
        };
        ErrorCode::result_to_c(call())
    }
}

impl Conversation for LibPamConversation<'_> {
    /// Sends `messages` through the wrapped C conversation callback and
    /// collects one [`Response`] per message.
    ///
    /// # Errors
    ///
    /// Returns [`ErrorCode::ConversationError`] in either of these cases:
    /// - there are more messages than a C `int` can count;
    /// - a message can't be converted to its C form.
    ///
    /// Otherwise it propagates the error status reported by the callback,
    /// or the first error hit while converting the responses.
    fn communicate(&mut self, messages: &[Message]) -> Result<Vec<Response>> {
        // The C API counts messages with an `int`; refuse rather than let an
        // `as` cast silently truncate the count handed to the callback.
        let count = c_int::try_from(messages.len()).map_err(|_| ErrorCode::ConversationError)?;
        let mut msgs_to_send = OwnedMessages::alloc(messages.len());
        for (dst, src) in iter::zip(msgs_to_send.iter_mut(), messages.iter()) {
            dst.set(*src).map_err(|_| ErrorCode::ConversationError)?
        }
        let mut response_pointer = std::ptr::null_mut();
        // SAFETY: We're calling into PAM with valid everything.
        // - `msgs_to_send` holds `count` messages set above and lives until
        //   the end of this function.
        // - `response_pointer` is a valid out-pointer.
        // - `appdata` is passed through unchanged.
        let result = unsafe {
            (self.callback)(
                count,
                msgs_to_send.indirector(),
                &mut response_pointer,
                self.appdata,
            )
        };
        ErrorCode::result_from(result)?;
        // SAFETY: This is a pointer we just got back from PAM. On success it
        // refers to one response per message, and we take ownership of it.
        let owned_responses =
            unsafe { OwnedResponses::from_c_heap(response_pointer, messages.len()) };
        convert_responses(messages, owned_responses)
    }
}

/// Turns the raw responses handed back by the callback into owned
/// [`Response`]s.
///
/// Each raw response is paired positionally with the message it answers.
/// Every pair is converted before any error is reported, because converting
/// an entry may also be what frees its contents.
///
/// # Errors
///
/// Returns the error from the first pair that failed to convert.
fn convert_responses(
    messages: &[Message],
    mut raw_responses: OwnedResponses,
) -> Result<Vec<Response>> {
    let converted: Vec<Result<Response>> = messages
        .iter()
        .zip(raw_responses.iter_mut())
        .map(convert)
        .collect();
    // Every entry has been visited; now surface the first failure, if any.
    converted.into_iter().collect()
}

/// Converts one message-to-raw pair to a Response.
///
/// `received` is interpreted according to the kind of message that was `sent`:
/// - text prompts yield text responses;
/// - binary prompts yield binary responses;
/// - error/info messages yield [`Response::NoResponse`] without touching
///   `received`.
///
/// For text and binary responses, the raw contents are copied into an owned
/// value and then released with `free_contents`. Text contents are released
/// even when reading them fails, so the free is never skipped on the error
/// path.
///
/// # Errors
///
/// Returns [`ConversationError`] if a text response's contents can't be read
/// (presumably because they are not valid UTF-8).
fn convert((sent, received): (&Message, &mut RawResponse)) -> Result<Response> {
    Ok(match sent {
        Message::MaskedPrompt(_) => {
            // SAFETY: Since this is a response to a text message,
            // we know it is text.
            let text_resp = unsafe { RawTextResponse::upcast(received) };
            // Copy the answer out, but hold any error until after the free:
            // the contents must be released on both success and failure.
            let ret = text_resp
                .contents()
                .map(|text| Response::MaskedText(text.into()))
                .map_err(|_| ConversationError);
            // SAFETY: We're the only ones using this,
            // and we haven't freed it.
            text_resp.free_contents();
            ret?
        }
        Message::Prompt(_) | Message::RadioPrompt(_) => {
            // SAFETY: Since this is a response to a text message,
            // we know it is text.
            let text_resp = unsafe { RawTextResponse::upcast(received) };
            // As above: free first, then propagate any read error.
            let ret = text_resp
                .contents()
                .map(|text| Response::Text(text.into()))
                .map_err(|_| ConversationError);
            // SAFETY: We're the only ones using this,
            // and we haven't freed it.
            text_resp.free_contents();
            ret?
        }
        Message::ErrorMsg(_) | Message::InfoMsg(_) => Response::NoResponse,
        Message::BinaryPrompt { .. } => {
            // SAFETY: Since this is a response to a binary message,
            // we know it is binary.
            let bin_resp = unsafe { RawBinaryResponse::upcast(received) };
            let ret = Response::Binary(bin_resp.to_owned());
            // SAFETY: We're the only ones using this,
            // and we haven't freed it.
            bin_resp.free_contents();
            ret
        }
    })
}