diff src/libpam/conversation.rs @ 77:351bdc13005e

Update the libpam module to work with the new structure.
author Paul Fisher <paul@pfish.zone>
date Sun, 08 Jun 2025 01:03:46 -0400
parents c30811b4afae
children 002adfb98c5c
line wrap: on
line diff
--- a/src/libpam/conversation.rs	Sat Jun 07 18:55:27 2025 -0400
+++ b/src/libpam/conversation.rs	Sun Jun 08 01:03:46 2025 -0400
@@ -1,14 +1,15 @@
-use crate::constants::Result;
-use crate::conv::{Conversation, Message, Response};
+use crate::conv::{
+    BinaryQAndA, Conversation, ErrorMsg, InfoMsg, MaskedQAndA, Message, QAndA,
+    RadioQAndA,
+};
 use crate::libpam::memory::Immovable;
-use crate::libpam::message::{MessageIndirector, OwnedMessages};
-use crate::libpam::response::{OwnedResponses, RawBinaryResponse, RawResponse, RawTextResponse};
+use crate::libpam::message::{Indirect, Questions};
+use crate::libpam::response::{Answer, Answers, BinaryAnswer, TextAnswer};
 use crate::ErrorCode;
-use crate::ErrorCode::ConversationError;
+use crate::Result;
 use std::ffi::c_int;
 use std::iter;
 use std::marker::PhantomData;
-use std::result::Result as StdResult;
 
 /// An opaque structure that is passed through PAM in a conversation.
 #[repr(C)]
@@ -23,15 +24,15 @@
 /// - `messages` is a pointer to the messages being sent to the user.
 ///   For details about its structure, see the documentation of
 ///   [`OwnedMessages`](super::OwnedMessages).
-/// - `responses` is a pointer to an array of [`RawResponse`]s,
+/// - `responses` is a pointer to an array of [`Answer`]s,
 ///   which PAM sets in response to a module's request.
 ///   This is an array of structs, not an array of pointers to a struct.
 ///   There should always be exactly as many `responses` as `num_msg`.
 /// - `appdata` is the `appdata` field of the [`LibPamConversation`] we were passed.
 pub type ConversationCallback = unsafe extern "C" fn(
     num_msg: c_int,
-    messages: *const MessageIndirector,
-    responses: *mut *mut RawResponse,
+    messages: *const Indirect,
+    responses: *mut *mut Answer,
     appdata: *mut AppData,
 ) -> c_int;
 
@@ -59,104 +60,116 @@
 
     unsafe extern "C" fn wrapper_callback<C: Conversation>(
         count: c_int,
-        messages: *const MessageIndirector,
-        responses: *mut *mut RawResponse,
+        questions: *const Indirect,
+        answers: *mut *mut Answer,
         me: *mut AppData,
     ) -> c_int {
-        let call = || {
+        let internal = || {
+            // Collect all our pointers
             let conv = me
                 .cast::<C>()
                 .as_mut()
                 .ok_or(ErrorCode::ConversationError)?;
-            let indir = messages.as_ref().ok_or(ErrorCode::ConversationError)?;
-            let response_ptr = responses.as_mut().ok_or(ErrorCode::ConversationError)?;
-            let messages: Vec<Message> = indir
+            let indirect = questions.as_ref().ok_or(ErrorCode::ConversationError)?;
+            let answers_ptr = answers.as_mut().ok_or(ErrorCode::ConversationError)?;
+
+            // Build our owned list of Q&As from the questions we've been asked
+            let messages: Vec<OwnedMessage> = indirect
                 .iter(count as usize)
-                .map(Message::try_from)
-                .collect::<StdResult<_, _>>()
+                .map(OwnedMessage::try_from)
+                .collect::<Result<_>>()
                 .map_err(|_| ErrorCode::ConversationError)?;
-            let responses = conv.communicate(&messages)?;
-            let owned =
-                OwnedResponses::build(&responses).map_err(|_| ErrorCode::ConversationError)?;
-            *response_ptr = owned.into_ptr();
+            // Borrow all those Q&As and ask them
+            let borrowed: Vec<Message> = messages.iter().map(Into::into).collect();
+            conv.communicate(&borrowed);
+
+            // Send our answers back
+            let owned = Answers::build(messages).map_err(|_| ErrorCode::ConversationError)?;
+            *answers_ptr = owned.into_ptr();
             Ok(())
         };
-        ErrorCode::result_to_c(call())
+        ErrorCode::result_to_c(internal())
     }
 }
 
 impl Conversation for LibPamConversation<'_> {
-    fn communicate(&mut self, messages: &[Message]) -> Result<Vec<Response>> {
-        let mut msgs_to_send = OwnedMessages::alloc(messages.len());
-        for (dst, src) in iter::zip(msgs_to_send.iter_mut(), messages.iter()) {
-            dst.set(*src).map_err(|_| ErrorCode::ConversationError)?
+    fn communicate(&mut self, messages: &[Message]) {
+        let internal = || {
+            let mut msgs_to_send = Questions::alloc(messages.len());
+            for (dst, src) in iter::zip(msgs_to_send.iter_mut(), messages.iter()) {
+                dst.fill(src).map_err(|_| ErrorCode::ConversationError)?
+            }
+            let mut response_pointer = std::ptr::null_mut();
+            // SAFETY: We're calling into PAM with valid everything.
+            let result = unsafe {
+                (self.callback)(
+                    messages.len() as c_int,
+                    msgs_to_send.indirect(),
+                    &mut response_pointer,
+                    self.appdata,
+                )
+            };
+            ErrorCode::result_from(result)?;
+            // SAFETY: This is a pointer we just got back from PAM.
+            // We have to trust that the responses from PAM match up
+            // with the questions we sent.
+            unsafe {
+                let mut owned_responses = Answers::from_c_heap(response_pointer, messages.len());
+                for (msg, response) in iter::zip(messages, owned_responses.iter_mut()) {
+                    convert(msg, response);
+                }
+            };
+            Ok(())
+        };
+        if let Err(e) = internal() {
+            messages.iter().for_each(|m| m.set_error(e))
         }
-        let mut response_pointer = std::ptr::null_mut();
-        // SAFETY: We're calling into PAM with valid everything.
-        let result = unsafe {
-            (self.callback)(
-                messages.len() as c_int,
-                msgs_to_send.indirector(),
-                &mut response_pointer,
-                self.appdata,
-            )
-        };
-        ErrorCode::result_from(result)?;
-        // SAFETY: This is a pointer we just got back from PAM.
-        let owned_responses =
-            unsafe { OwnedResponses::from_c_heap(response_pointer, messages.len()) };
-        convert_responses(messages, owned_responses)
     }
 }
 
-fn convert_responses(
-    messages: &[Message],
-    mut raw_responses: OwnedResponses,
-) -> Result<Vec<Response>> {
-    let pairs = iter::zip(messages.iter(), raw_responses.iter_mut());
-    // We first collect into a Vec of Results so that we always process
-    // every single entry, which may involve freeing it.
-    let responses: Vec<_> = pairs.map(convert).collect();
-    // Only then do we return the first error, if present.
-    responses.into_iter().collect()
+/// Like [`Message`], but this time we own the contents.
+pub enum OwnedMessage<'a> {
+    MaskedPrompt(MaskedQAndA<'a>),
+    Prompt(QAndA<'a>),
+    RadioPrompt(RadioQAndA<'a>),
+    BinaryPrompt(BinaryQAndA<'a>),
+    Info(InfoMsg<'a>),
+    Error(ErrorMsg<'a>),
+}
+
+impl<'a> From<&'a OwnedMessage<'a>> for Message<'a> {
+    fn from(src: &'a OwnedMessage) -> Self {
+        match src {
+            OwnedMessage::MaskedPrompt(m) => Message::MaskedPrompt(m),
+            OwnedMessage::Prompt(m) => Message::Prompt(m),
+            OwnedMessage::RadioPrompt(m) => Message::RadioPrompt(m),
+            OwnedMessage::BinaryPrompt(m) => Message::BinaryPrompt(m),
+            OwnedMessage::Info(m) => Message::Info(m),
+            OwnedMessage::Error(m) => Message::Error(m),
+        }
+    }
 }
 
-/// Converts one message-to-raw pair to a Response.
-fn convert((sent, received): (&Message, &mut RawResponse)) -> Result<Response> {
-    Ok(match sent {
-        Message::MaskedPrompt(_) => {
-            // SAFETY: Since this is a response to a text message,
-            // we know it is text.
-            let text_resp = unsafe { RawTextResponse::upcast(received) };
-            let ret = Response::MaskedText(
-                text_resp
-                    .contents()
-                    .map_err(|_| ErrorCode::ConversationError)?
-                    .into(),
-            );
-            // SAFETY: We're the only ones using this,
-            // and we haven't freed it.
-            text_resp.free_contents();
-            ret
+/// Fills in the answer of the Message with the given response.
+///
+/// # Safety
+///
+/// You are responsible for ensuring that the src-dst pair matches.
+unsafe fn convert(msg: &Message, resp: &mut Answer) {
+    macro_rules! fill_text {
+    ($dst:ident, $src:ident) => {{let text_resp = unsafe {TextAnswer::upcast($src)};
+    $dst.set_answer(text_resp.contents().map(Into::into));}}
+}
+    match *msg {
+        Message::MaskedPrompt(qa) => fill_text!(qa, resp),
+        Message::Prompt(qa) => fill_text!(qa, resp),
+        Message::RadioPrompt(qa) => fill_text!(qa, resp),
+        Message::Error(m) => m.set_answer(Ok(())),
+        Message::Info(m) => m.set_answer(Ok(())),
+        Message::BinaryPrompt(qa) => {
+            let bin_resp = unsafe { BinaryAnswer::upcast(resp) };
+            qa.set_answer(Ok(bin_resp.data().into()));
+            bin_resp.zero_contents()
         }
-        Message::Prompt(_) | Message::RadioPrompt(_) => {
-            // SAFETY: Since this is a response to a text message,
-            // we know it is text.
-            let text_resp = unsafe { RawTextResponse::upcast(received) };
-            let ret = Response::Text(text_resp.contents().map_err(|_| ConversationError)?.into());
-            // SAFETY: We're the only ones using this,
-            // and we haven't freed it.
-            text_resp.free_contents();
-            ret
-        }
-        Message::ErrorMsg(_) | Message::InfoMsg(_) => Response::NoResponse,
-        Message::BinaryPrompt { .. } => {
-            let bin_resp = unsafe { RawBinaryResponse::upcast(received) };
-            let ret = Response::Binary(bin_resp.to_owned());
-            // SAFETY: We're the only ones using this,
-            // and we haven't freed it.
-            bin_resp.free_contents();
-            ret
-        }
-    })
+    }
 }