view src/pam_ffi/conversation.rs @ 74:c7c596e6388f

Make conversations type-safe (last big reorg) (REAL) (NOT CLICKBAIT) In previous versions of Conversation, you could send messages and then return messages of the wrong type or in the wrong order or whatever. The receiver would then have to make sure that there were the right number of messages and that each message was the right type. That's annoying. This change makes the `Message` enum a two-way channel, where the asker puts their question into it, and then the answerer (the conversation) puts the answer in and returns control to the asker. The asker then only has to pull the Answer of the type they wanted out of the message.
author Paul Fisher <paul@pfish.zone>
date Fri, 06 Jun 2025 22:21:17 -0400
parents ac6881304c78
children
line wrap: on
line source

use crate::constants::Result;
use crate::conv::{Conversation, Message, Response};
use crate::pam_ffi::memory::Immovable;
use crate::pam_ffi::message::{MessageIndirector, OwnedMessages};
use crate::pam_ffi::response::{OwnedResponses, RawBinaryResponse, RawResponse, RawTextResponse};
use crate::ErrorCode;
use crate::ErrorCode::ConversationError;
use std::ffi::c_int;
use std::iter;
use std::marker::PhantomData;
use std::result::Result as StdResult;

/// An opaque structure that is passed through PAM in a conversation.
///
/// This type has no usable contents; within this module it is only ever
/// handled behind a `*mut AppData` pointer, which the conversation callback
/// casts back to the concrete conversation type it was created from.
#[repr(C)]
pub struct AppData {
    // Zero-sized placeholder so the struct has a field but carries no data.
    _data: (),
    // NOTE(review): presumably prevents values of this type from being
    // moved or constructed by accident — confirm against `memory::Immovable`.
    _marker: Immovable,
}

/// The callback that PAM uses to get information in a conversation.
///
/// - `num_msg` is the number of messages in the `pam_message` array.
/// - `messages` is a pointer to the messages being sent to the user.
///   For details about its structure, see the documentation of
///   [`OwnedMessages`](super::OwnedMessages).
/// - `responses` is an out-parameter: on success the callback stores in it
///   a pointer to an array of [`RawResponse`]s answering the request.
///   This is an array of structs, not an array of pointers to a struct.
///   There should always be exactly as many `responses` as `num_msg`.
/// - `appdata` is the `appdata` field of the [`LibPamConversation`] we were passed.
///
/// The returned `c_int` is a status code; this module produces it with
/// `ErrorCode::result_to_c` and interprets it with `ErrorCode::result_from`.
pub type ConversationCallback = unsafe extern "C" fn(
    num_msg: c_int,
    messages: *const MessageIndirector,
    responses: *mut *mut RawResponse,
    appdata: *mut AppData,
) -> c_int;

/// The type used by PAM to call back into a conversation.
///
/// It pairs a C-ABI callback with the opaque pointer that must be handed
/// back to that callback on every call.
#[repr(C)]
pub struct LibPamConversation<'a> {
    /// The function that is called to get information from the user.
    callback: ConversationCallback,
    /// The pointer that will be passed as the last parameter
    /// to the conversation callback.
    appdata: *mut AppData,
    // Carries the lifetime `'a` without storing a reference.
    // NOTE(review): presumably `'a` is the lifetime of whatever `appdata`
    // points to — confirm against the constructors.
    life: PhantomData<&'a mut ()>,
    // NOTE(review): presumably keeps this value from being moved once PAM
    // holds a pointer to it — confirm against `memory::Immovable`.
    _marker: Immovable,
}

impl<'a> LibPamConversation<'a> {
    /// Wraps a Rust [`Conversation`] so that it can be called through
    /// the C conversation callback interface.
    ///
    /// The returned value stores a raw pointer to `conv`, so it borrows
    /// `conv` mutably for its whole lifetime `'a`; this keeps the wrapper
    /// from outliving the conversation it points to.
    fn wrap<C: Conversation>(conv: &'a mut C) -> Self {
        Self {
            callback: Self::wrapper_callback::<C>,
            appdata: (conv as *mut C).cast(),
            life: PhantomData,
            _marker: Immovable(PhantomData),
        }
    }

    /// The C-ABI trampoline installed by [`Self::wrap`].
    ///
    /// Converts the incoming raw messages to [`Message`]s, hands them to the
    /// wrapped conversation `C`, and writes the resulting response array to
    /// `*responses`. Returns the C status code for the outcome.
    ///
    /// Any null pointer, a negative `count`, an unconvertible message, or a
    /// conversation that returns the wrong number of responses is reported
    /// as `ErrorCode::ConversationError`, and `*responses` is left untouched.
    ///
    /// # Safety
    ///
    /// - `me` must be null or point to a live `C` that nothing else is
    ///   accessing for the duration of the call.
    /// - `messages` must be null or point to `count` valid messages.
    /// - `responses` must be null or valid for writing one pointer.
    unsafe extern "C" fn wrapper_callback<C: Conversation>(
        count: c_int,
        messages: *const MessageIndirector,
        responses: *mut *mut RawResponse,
        me: *mut AppData,
    ) -> c_int {
        let call = || {
            // A negative count would wrap to an enormous length if cast
            // with `as`, so reject it explicitly.
            let count = usize::try_from(count).map_err(|_| ErrorCode::ConversationError)?;
            let conv = me
                .cast::<C>()
                .as_mut()
                .ok_or(ErrorCode::ConversationError)?;
            let indir = messages.as_ref().ok_or(ErrorCode::ConversationError)?;
            let response_ptr = responses.as_mut().ok_or(ErrorCode::ConversationError)?;
            let messages: Vec<Message> = indir
                .iter(count)
                .map(Message::try_from)
                .collect::<StdResult<_, _>>()
                .map_err(|_| ErrorCode::ConversationError)?;
            let responses = conv.communicate(&messages)?;
            // The caller will read exactly `count` responses, so handing
            // back an array of any other length would be unsound.
            if responses.len() != messages.len() {
                return Err(ErrorCode::ConversationError);
            }
            let owned =
                OwnedResponses::build(&responses).map_err(|_| ErrorCode::ConversationError)?;
            *response_ptr = owned.into_ptr();
            Ok(())
        };
        ErrorCode::result_to_c(call())
    }
}

impl Conversation for LibPamConversation<'_> {
    fn communicate(&mut self, messages: &[Message]) -> Result<Vec<Response>> {
        let mut msgs_to_send = OwnedMessages::alloc(messages.len());
        for (dst, src) in iter::zip(msgs_to_send.iter_mut(), messages.iter()) {
            dst.set(*src).map_err(|_| ErrorCode::ConversationError)?
        }
        let mut response_pointer = std::ptr::null_mut();
        // SAFETY: We're calling into PAM with valid everything.
        let result = unsafe {
            (self.callback)(
                messages.len() as c_int,
                msgs_to_send.indirector(),
                &mut response_pointer,
                self.appdata,
            )
        };
        ErrorCode::result_from(result)?;
        // SAFETY: This is a pointer we just got back from PAM.
        let owned_responses =
            unsafe { OwnedResponses::from_c_heap(response_pointer, messages.len()) };
        convert_responses(messages, owned_responses)
    }
}

fn convert_responses(
    messages: &[Message],
    mut raw_responses: OwnedResponses,
) -> Result<Vec<Response>> {
    let pairs = iter::zip(messages.iter(), raw_responses.iter_mut());
    // We first collect into a Vec of Results so that we always process
    // every single entry, which may involve freeing it.
    let responses: Vec<_> = pairs.map(convert).collect();
    // Only then do we return the first error, if present.
    responses.into_iter().collect()
}

/// Converts one message-to-raw pair to a Response.
///
/// The kind of message that was `sent` decides how the `received` raw
/// response is interpreted. For prompts, the raw response's contents are
/// copied into the returned [`Response`] and then freed; the contents are
/// freed even when they cannot be converted.
///
/// # Errors
///
/// Returns `ConversationError` if the contents of a text response
/// cannot be read as text.
fn convert((sent, received): (&Message, &mut RawResponse)) -> Result<Response> {
    match sent {
        Message::MaskedPrompt(_) => {
            // SAFETY: Since this is a response to a text message,
            // we know it is text.
            let text_resp = unsafe { RawTextResponse::upcast(received) };
            let ret = text_resp
                .contents()
                .map(|text| Response::MaskedText(text.into()))
                .map_err(|_| ErrorCode::ConversationError);
            // Free unconditionally: returning early on a conversion error
            // would leave the (secret) contents allocated.
            text_resp.free_contents();
            ret
        }
        Message::Prompt(_) | Message::RadioPrompt(_) => {
            // SAFETY: Since this is a response to a text message,
            // we know it is text.
            let text_resp = unsafe { RawTextResponse::upcast(received) };
            let ret = text_resp
                .contents()
                .map(|text| Response::Text(text.into()))
                .map_err(|_| ConversationError);
            // Free unconditionally, even if the conversion above failed.
            text_resp.free_contents();
            ret
        }
        // Purely informational messages carry no answer.
        Message::ErrorMsg(_) | Message::InfoMsg(_) => Ok(Response::NoResponse),
        Message::BinaryPrompt { .. } => {
            // SAFETY: Since this is a response to a binary message,
            // we expect it to be binary.
            let bin_resp = unsafe { RawBinaryResponse::upcast(received) };
            let ret = Response::Binary(bin_resp.to_owned());
            // We're the only ones using this, and we haven't freed it.
            bin_resp.free_contents();
            Ok(ret)
        }
    }
}