Mercurial > crates > nonstick
diff src/module.rs @ 70:9f8381a1c09c
Implement low-level conversation primitives.
This change does two primary things:
1. Introduces new Conversation traits, to be implemented both
by the library and by PAM client applications.
2. Builds the memory-management infrastructure for passing messages
through the conversation.
...and it adds tests for both of the above, including ASAN tests.
author | Paul Fisher <paul@pfish.zone> |
---|---|
date | Tue, 03 Jun 2025 01:21:59 -0400 |
parents | a674799a5cd3 |
children | 58f9d2a4df38 |
line wrap: on
line diff
--- a/src/module.rs Sun Jun 01 01:15:04 2025 -0400 +++ b/src/module.rs Tue Jun 03 01:21:59 2025 -0400 @@ -1,7 +1,13 @@ //! Functions and types useful for implementing a PAM module. +// Temporarily allowed until we get the actual conversation functions hooked up. +#![allow(dead_code)] + use crate::constants::{ErrorCode, Flags, Result}; +use crate::conv::BinaryData; +use crate::conv::{Conversation, Message, Response}; use crate::handle::PamModuleHandle; +use secure_string::SecureString; use std::ffi::CStr; /// A trait for a PAM module to implement. @@ -233,6 +239,81 @@ } } +/// Provides methods to make it easier to send exactly one message. +/// +/// This is primarily used by PAM modules, so that a module that only needs +/// one piece of information at a time doesn't have a ton of boilerplate. +/// You may also find it useful for testing PAM application libraries. +/// +/// ``` +/// # use nonstick::Result; +/// # use nonstick::conv::Conversation; +/// # use nonstick::module::ConversationMux; +/// # fn _do_test(conv: impl Conversation) -> Result<()> { +/// let mut mux = ConversationMux(conv); +/// let token = mux.masked_prompt("enter your one-time token")?; +/// # Ok(()) +/// # } +pub struct ConversationMux<C: Conversation>(pub C); + +impl<C: Conversation> Conversation for ConversationMux<C> { + fn send(&mut self, messages: &[Message]) -> Result<Vec<Response>> { + self.0.send(messages) + } +} + +impl<C: Conversation> ConversationMux<C> { + /// Prompts the user for something. + pub fn prompt(&mut self, request: &str) -> Result<String> { + let resp = self.send(&[Message::Prompt(request)])?.pop(); + match resp { + Some(Response::Text(s)) => Ok(s), + _ => Err(ErrorCode::ConversationError), + } + } + + /// Prompts the user for something, but hides what the user types. + pub fn masked_prompt(&mut self, request: &str) -> Result<SecureString> { + let resp = self.send(&[Message::MaskedPrompt(request)])?.pop(); + match resp { + Some(Response::MaskedText(s)) => Ok(s), + _ => Err(ErrorCode::ConversationError), + } + } + + /// Prompts the user for a yes/no/maybe conditional (a Linux-PAM extension). + /// + /// PAM documentation doesn't define the format of the response. + pub fn radio_prompt(&mut self, request: &str) -> Result<String> { + let resp = self.send(&[Message::RadioPrompt(request)])?.pop(); + match resp { + Some(Response::Text(s)) => Ok(s), + _ => Err(ErrorCode::ConversationError), + } + } + + /// Alerts the user to an error. + pub fn error(&mut self, message: &str) { + let _ = self.send(&[Message::Error(message)]); + } + + /// Sends an informational message to the user. + pub fn info(&mut self, message: &str) { + let _ = self.send(&[Message::Info(message)]); + } + + /// Requests binary data from the user (a Linux-PAM extension). + pub fn binary_prompt(&mut self, data: &[u8], data_type: u8) -> Result<BinaryData> { + let resp = self + .send(&[Message::BinaryPrompt { data, data_type }])? + .pop(); + match resp { + Some(Response::Binary(d)) => Ok(d), + _ => Err(ErrorCode::ConversationError), + } + } +} + /// Generates the dynamic library entry points for a [PamModule] implementation. /// /// Calling `pam_hooks!(SomeType)` on a type that implements [PamModule] will @@ -379,11 +460,91 @@ } #[cfg(test)] -pub mod test { - use crate::module::{PamModule, PamModuleHandle}; +mod test { + use super::{ + Conversation, ConversationMux, ErrorCode, Message, Response, Result, SecureString, + }; + + /// Compile-time test that the `pam_hooks` macro compiles. + mod hooks { + use super::super::{PamModule, PamModuleHandle}; + struct Foo; + impl<T: PamModuleHandle> PamModule<T> for Foo {} + + pam_hooks!(Foo); + } + + #[test] + fn test_mux() { + struct MuxTester; - struct Foo; - impl<T: PamModuleHandle> PamModule<T> for Foo {} + impl Conversation for MuxTester { + fn send(&mut self, messages: &[Message]) -> Result<Vec<Response>> { + if let [msg] = messages { + match msg { + Message::Info(info) => { + assert_eq!("let me tell you", *info); + Ok(vec![Response::NoResponse]) + } + Message::Error(error) => { + assert_eq!("oh no", *error); + Ok(vec![Response::NoResponse]) + } + Message::Prompt("should_error") => Err(ErrorCode::BufferError), + Message::Prompt(ask) => { + assert_eq!("question", *ask); + Ok(vec![Response::Text("answer".to_owned())]) + } + Message::MaskedPrompt("return_wrong_type") => { + Ok(vec![Response::NoResponse]) + } + Message::MaskedPrompt(ask) => { + assert_eq!("password!", *ask); + Ok(vec![Response::MaskedText(SecureString::from( + "open sesame", + ))]) + } + Message::BinaryPrompt { data, data_type } => { + assert_eq!(&[1, 2, 3], data); + assert_eq!(69, *data_type); + Ok(vec![Response::Binary(super::BinaryData::new( + vec![3, 2, 1], + 42, + ))]) + } + Message::RadioPrompt(ask) => { + assert_eq!("radio?", *ask); + Ok(vec![Response::Text("yes".to_owned())]) + } + } + } else { + panic!("messages is the wrong size ({len})", len = messages.len()) + } + } + } - pam_hooks!(Foo); + let mut mux = ConversationMux(MuxTester); + assert_eq!("answer", mux.prompt("question").unwrap()); + assert_eq!( + SecureString::from("open sesame"), + mux.masked_prompt("password!").unwrap() + ); + mux.error("oh no"); + mux.info("let me tell you"); + { + assert_eq!("yes", mux.radio_prompt("radio?").unwrap()); + assert_eq!( + super::BinaryData::new(vec![3, 2, 1], 42), + mux.binary_prompt(&[1, 2, 3], 69).unwrap(), + ) + } + assert_eq!( + ErrorCode::BufferError, + mux.prompt("should_error").unwrap_err(), + ); + assert_eq!( + ErrorCode::ConversationError, + mux.masked_prompt("return_wrong_type").unwrap_err() + ) + } }