diff src/libpam/memory.rs @ 78:002adfb98c5c

Rename files, reorder structs, remove annoying BorrowedBinaryData type. This is basically a cleanup change. Also it adds tests. - Renames the files with Questions and Answers to question and answer. - Reorders the structs in those files to put the important ones first. - Removes the BorrowedBinaryData type. It was a bad idea all along. Instead, we just use (&[u8], u8). - Adds some tests because I just can't help myself.
author Paul Fisher <paul@pfish.zone>
date Sun, 08 Jun 2025 03:48:40 -0400
parents 351bdc13005e
children 2128123b9406
line wrap: on
line diff
--- a/src/libpam/memory.rs	Sun Jun 08 01:03:46 2025 -0400
+++ b/src/libpam/memory.rs	Sun Jun 08 03:48:40 2025 -0400
@@ -1,12 +1,28 @@
 //! Things for dealing with memory.
 
-use crate::conv::BorrowedBinaryData;
 use crate::Result;
 use crate::{BinaryData, ErrorCode};
-use std::ffi::{c_char, c_void, CStr, CString};
+use std::ffi::{c_char, CStr, CString};
 use std::marker::{PhantomData, PhantomPinned};
 use std::{ptr, slice};
 
+/// Allocates `count` elements to hold `T`.
+#[inline]
+pub fn calloc<T>(count: usize) -> *mut T {
+    // SAFETY: it's always safe to allocate! Leaking memory is fun!
+    unsafe { libc::calloc(count, size_of::<T>()) }.cast()
+}
+
+/// Wrapper for [`libc::free`] to make debugging calls/frees easier.
+///
+/// # Safety
+///
+/// If you double-free, it's all your fault.
+#[inline]
+pub unsafe fn free<T>(p: *mut T) {
+    libc::free(p.cast())
+}
+
 /// Makes whatever it's in not [`Send`], [`Sync`], or [`Unpin`].
 #[repr(C)]
 #[derive(Debug)]
@@ -62,7 +78,7 @@
 
 /// Allocates a string with the given contents on the C heap.
 ///
-/// This is like [`CString::new`](std::ffi::CString::new), but:
+/// This is like [`CString::new`], but:
 ///
 /// - it allocates data on the C heap with [`libc::malloc`].
 /// - it doesn't take ownership of the data passed in.
@@ -71,11 +87,12 @@
     if data.contains(&0) {
         return Err(ErrorCode::ConversationError);
     }
+    let data_alloc: *mut c_char = calloc(data.len() + 1);
+    // SAFETY: we just allocated this and we have enough room.
     unsafe {
-        let data_alloc = libc::calloc(data.len() + 1, 1);
-        libc::memcpy(data_alloc, data.as_ptr().cast(), data.len());
-        Ok(data_alloc.cast())
+        libc::memcpy(data_alloc.cast(), data.as_ptr().cast(), data.len());
     }
+    Ok(data_alloc)
 }
 
 /// Writes zeroes over the contents of a C string.
@@ -85,9 +102,9 @@
 /// # Safety
 ///
 /// It's up to you to provide a valid C string.
-pub unsafe fn zero_c_string(cstr: *mut c_void) {
+pub unsafe fn zero_c_string(cstr: *mut c_char) {
     if !cstr.is_null() {
-        libc::memset(cstr, 0, libc::strlen(cstr.cast()));
+        libc::memset(cstr.cast(), 0, libc::strlen(cstr.cast()));
     }
 }
 
@@ -110,20 +127,20 @@
 
 impl CBinaryData {
     /// Copies the given data to a new BinaryData on the heap.
-    pub fn alloc(source: &[u8], data_type: u8) -> Result<*mut CBinaryData> {
+    pub fn alloc((data, data_type): (&[u8], u8)) -> Result<*mut CBinaryData> {
         let buffer_size =
-            u32::try_from(source.len() + 5).map_err(|_| ErrorCode::ConversationError)?;
+            u32::try_from(data.len() + 5).map_err(|_| ErrorCode::ConversationError)?;
         // SAFETY: We're only allocating here.
-        let data = unsafe {
-            let dest_buffer: *mut CBinaryData = libc::malloc(buffer_size as usize).cast();
-            let data = &mut *dest_buffer;
-            data.total_length = buffer_size.to_be_bytes();
-            data.data_type = data_type;
-            let dest = data.data.as_mut_ptr();
-            libc::memcpy(dest.cast(), source.as_ptr().cast(), source.len());
+        let dest = unsafe {
+            let dest_buffer: *mut CBinaryData = calloc::<u8>(buffer_size as usize).cast();
+            let dest = &mut *dest_buffer;
+            dest.total_length = buffer_size.to_be_bytes();
+            dest.data_type = data_type;
+            let dest = dest.data.as_mut_ptr();
+            libc::memcpy(dest.cast(), data.as_ptr().cast(), data.len());
             dest_buffer
         };
-        Ok(data)
+        Ok(dest)
     }
 
     fn length(&self) -> usize {
@@ -141,39 +158,42 @@
     }
 }
 
-impl<'a> From<&'a CBinaryData> for BorrowedBinaryData<'a> {
+impl<'a> From<&'a CBinaryData> for (&'a[u8], u8) {
     fn from(value: &'a CBinaryData) -> Self {
-        BorrowedBinaryData::new(
-            unsafe { slice::from_raw_parts(value.data.as_ptr(), value.length()) },
-            value.data_type,
-        )
+        (unsafe { slice::from_raw_parts(value.data.as_ptr(), value.length()) },
+            value.data_type        )
     }
 }
 
 impl From<Option<&'_ CBinaryData>> for BinaryData {
     fn from(value: Option<&CBinaryData>) -> Self {
-        value.map(BorrowedBinaryData::from).map(Into::into).unwrap_or_default()
+        // This is a dumb trick but I like it because it is simply the presence
+        // of `.map(|(x, y)| (x, y))` in the middle of this that gives
+        // type inference the hint it needs to make this work.
+        value
+            .map(Into::into)
+            .map(|(data, data_type)| (data, data_type))
+            .map(Into::into)
+            .unwrap_or_default()
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use super::{copy_pam_string, malloc_str, option_cstr, prompt_ptr, zero_c_string};
-    use crate::ErrorCode;
-    use std::ffi::CString;
+    use super::{free, ErrorCode, CString, copy_pam_string, malloc_str, option_cstr, prompt_ptr, zero_c_string};
     #[test]
     fn test_strings() {
         let str = malloc_str("hello there").unwrap();
         malloc_str("hell\0 there").unwrap_err();
         unsafe {
-            let copied = copy_pam_string(str.cast()).unwrap();
+            let copied = copy_pam_string(str).unwrap();
             assert_eq!("hello there", copied);
-            zero_c_string(str.cast());
+            zero_c_string(str);
             let idx_three = str.add(3).as_mut().unwrap();
             *idx_three = 0x80u8 as i8;
-            let zeroed = copy_pam_string(str.cast()).unwrap();
+            let zeroed = copy_pam_string(str).unwrap();
             assert!(zeroed.is_empty());
-            libc::free(str.cast());
+            free(str);
         }
     }