diff src/pam_ffi/memory.rs @ 71:58f9d2a4df38

Reorganize everything again??? - Splits ffi/memory stuff into a bunch of stuff in the pam_ffi module. - Builds infrastructure for passing Messages and Responses. - Adds tests for some things at least.
author Paul Fisher <paul@pfish.zone>
date Tue, 03 Jun 2025 21:54:58 -0400
parents src/memory.rs@bbe84835d6db
children 47eb242a4f88
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/pam_ffi/memory.rs	Tue Jun 03 21:54:58 2025 -0400
@@ -0,0 +1,187 @@
+//! Things for dealing with memory.
+
+use crate::ErrorCode;
+use crate::Result;
+use std::ffi::{c_char, c_void, CStr, CString};
+use std::marker::{PhantomData, PhantomPinned};
+use std::result::Result as StdResult;
+use std::{ptr, slice};
+
+/// Makes whatever it's in not [`Send`], [`Sync`], or [`Unpin`].
+pub type Immovable = PhantomData<(*mut u8, PhantomPinned)>;
+
+/// Safely converts a `&str` option to a `CString` option.
+pub fn option_cstr(prompt: Option<&str>) -> Result<Option<CString>> {
+    prompt
+        .map(CString::new)
+        .transpose()
+        .map_err(|_| ErrorCode::ConversationError)
+}
+
+/// Gets the pointer to the given CString, or a null pointer if absent.
+pub fn prompt_ptr(prompt: Option<&CString>) -> *const c_char {
+    match prompt {
+        Some(c_str) => c_str.as_ptr(),
+        None => ptr::null(),
+    }
+}
+
+/// Creates an owned copy of a string that is returned from a
+/// <code>pam_get_<var>whatever</var></code> function.
+pub unsafe fn copy_pam_string(result_ptr: *const libc::c_char) -> Result<String> {
+    // We really shouldn't get a null pointer back here, but if we do, return nothing.
+    if result_ptr.is_null() {
+        return Ok(String::new());
+    }
+    let bytes = unsafe { CStr::from_ptr(result_ptr) };
+    bytes
+        .to_str()
+        .map(String::from)
+        .map_err(|_| ErrorCode::ConversationError)
+}
+
+/// Allocates a string with the given contents on the C heap.
+///
+/// This is like [`CString::new`](std::ffi::CString::new), but:
+///
+/// - it allocates data on the C heap with [`libc::malloc`].
+/// - it doesn't take ownership of the data passed in.
+pub fn malloc_str(text: impl AsRef<str>) -> StdResult<*mut c_char, NulError> {
+    let data = text.as_ref().as_bytes();
+    if let Some(nul) = data.iter().position(|x| *x == 0) {
+        return Err(NulError(nul));
+    }
+    unsafe {
+        let data_alloc = libc::calloc(data.len() + 1, 1);
+        libc::memcpy(data_alloc, data.as_ptr() as *const c_void, data.len());
+        Ok(data_alloc.cast())
+    }
+}
+
+/// Writes zeroes over the contents of a C string.
+///
+/// This won't overwrite a null pointer.
+///
+/// # Safety
+///
+/// It's up to you to provide a valid C string.
+pub unsafe fn zero_c_string(cstr: *mut c_void) {
+    if !cstr.is_null() {
+        libc::memset(cstr, 0, libc::strlen(cstr as *const c_char));
+    }
+}
+
+/// Binary data used in requests and responses.
+///
+/// This is an unsized data type whose memory goes beyond its data.
+/// This must be allocated on the C heap.
+///
+/// A Linux-PAM extension.
+#[repr(C)]
+pub struct CBinaryData {
+    /// The total length of the structure; a u32 in network byte order (BE).
+    total_length: [u8; 4],
+    /// A tag of undefined meaning.
+    data_type: u8,
+    /// Pointer to an array of length [`length`](Self::length) − 5
+    data: [u8; 0],
+    _marker: Immovable,
+}
+
+impl CBinaryData {
+    /// Copies the given data to a new BinaryData on the heap.
+    pub fn alloc(source: &[u8], data_type: u8) -> StdResult<*mut CBinaryData, TooBigError> {
+        let buffer_size = u32::try_from(source.len() + 5).map_err(|_| TooBigError {
+            max: (u32::MAX - 5) as usize,
+            actual: source.len(),
+        })?;
+        let data = unsafe {
+            let dest_buffer = libc::malloc(buffer_size as usize) as *mut CBinaryData;
+            let data = &mut *dest_buffer;
+            data.total_length = buffer_size.to_be_bytes();
+            data.data_type = data_type;
+            let dest = data.data.as_mut_ptr();
+            libc::memcpy(
+                dest as *mut c_void,
+                source.as_ptr() as *const c_void,
+                source.len(),
+            );
+            dest_buffer
+        };
+        Ok(data)
+    }
+
+    fn length(&self) -> usize {
+        u32::from_be_bytes(self.total_length).saturating_sub(5) as usize
+    }
+
+    pub fn contents(&self) -> &[u8] {
+        unsafe { slice::from_raw_parts(self.data.as_ptr(), self.length()) }
+    }
+    pub fn data_type(&self) -> u8 {
+        self.data_type
+    }
+
+    /// Clears this data and frees it.
+    pub unsafe fn zero_contents(&mut self) {
+        let contents = slice::from_raw_parts_mut(self.data.as_mut_ptr(), self.length());
+        for v in contents {
+            *v = 0
+        }
+        self.data_type = 0;
+        self.total_length = [0; 4];
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("null byte within input at byte {0}")]
+pub struct NulError(pub usize);
+
+/// Returned when trying to fit too much data into a binary message.
+#[derive(Debug, thiserror::Error)]
+#[error("cannot create a message of {actual} bytes; maximum is {max}")]
+pub struct TooBigError {
+    pub actual: usize,
+    pub max: usize,
+}
+
+#[cfg(test)]
+mod tests {
+    use std::ffi::CString;
+    use crate::ErrorCode;
+    use super::{copy_pam_string, malloc_str, option_cstr, prompt_ptr, zero_c_string};
+    #[test]
+    fn test_strings() {
+        let str = malloc_str("hello there").unwrap();
+        malloc_str("hell\0 there").unwrap_err();
+        unsafe {
+            let copied = copy_pam_string(str.cast()).unwrap();
+            assert_eq!("hello there", copied);
+            zero_c_string(str.cast());
+            let idx_three = str.add(3).as_mut().unwrap();
+            *idx_three = 0x80u8 as i8;
+            let zeroed = copy_pam_string(str.cast()).unwrap();
+            assert!(zeroed.is_empty());
+            libc::free(str.cast());
+        }
+    }
+    
+    #[test]
+    fn test_option_str() {
+        let good = option_cstr(Some("whatever")).unwrap();
+        assert_eq!("whatever", good.unwrap().to_str().unwrap());
+        let no_str = option_cstr(None).unwrap();
+        assert!(no_str.is_none());
+        let bad_str = option_cstr(Some("what\0ever")).unwrap_err();
+        assert_eq!(ErrorCode::ConversationError, bad_str);
+    }
+    
+    #[test]
+    fn test_prompt() {
+        let prompt_cstr = CString::new("good").ok();
+        let prompt = prompt_ptr(prompt_cstr.as_ref());
+        assert!(!prompt.is_null());
+        let no_prompt = prompt_ptr(None);
+        assert!(no_prompt.is_null());
+    }
+}