view src/libpam/memory.rs @ 95:51c9d7e8261a

Return owned strings rather than borrowed strings. It's going to be irritating to have to work with strings borrowed from the PAM handle rather than just using your own. They're cheap enough to copy.
author Paul Fisher <paul@pfish.zone>
date Mon, 23 Jun 2025 14:03:44 -0400
parents efc2b56c8928
children b87100c5eed4
line wrap: on
line source

//! Things for dealing with memory.

use crate::Result;
use crate::{BinaryData, ErrorCode};
use std::ffi::{c_char, CStr, CString};
use std::marker::{PhantomData, PhantomPinned};
use std::mem::offset_of;
use std::ptr::NonNull;
use std::{mem, ptr, slice};

/// Allocates zero-initialized room for `count` values of type `T` using
/// [`libc::calloc`].
///
/// The returned pointer may be null if the C allocator fails, so callers
/// must check it before use. Release the memory with [`free`].
#[inline]
pub fn calloc<T>(count: usize) -> *mut T {
    let elem_size = size_of::<T>();
    // SAFETY: `calloc` has no preconditions; at worst this leaks or returns null.
    let raw = unsafe { libc::calloc(count, elem_size) };
    raw.cast::<T>()
}

/// Wrapper for [`libc::free`] to make debugging calls/frees easier.
///
/// # Safety
///
/// If you double-free, it's all your fault.
#[inline]
pub unsafe fn free<T>(p: *mut T) {
    libc::free(p.cast())
}

/// Makes whatever it's in not [`Send`], [`Sync`], or [`Unpin`].
///
/// This is a zero-sized marker (it only holds a [`PhantomData`]):
/// - the `*mut u8` inside the `PhantomData` opts the containing type out of
///   `Send` and `Sync`;
/// - the [`PhantomPinned`] opts it out of `Unpin`.
///
/// Embed it as a field in types that describe memory owned elsewhere
/// (e.g. on the C heap) and must not be moved or shared across threads.
#[repr(C)]
#[derive(Debug, Default)]
pub struct Immovable(pub PhantomData<(*mut u8, PhantomPinned)>);

/// Converts an optional `&str` into an optional owned [`CString`].
///
/// `None` passes through unchanged.
///
/// # Errors
///
/// Returns [`ErrorCode::ConversationError`] if the string contains an
/// interior NUL byte and so cannot be represented as a C string.
pub fn option_cstr(prompt: Option<&str>) -> Result<Option<CString>> {
    match prompt {
        None => Ok(None),
        Some(text) => CString::new(text)
            .map(Some)
            .map_err(|_| ErrorCode::ConversationError),
    }
}

/// Returns the raw pointer behind an optional [`CString`], or null for `None`.
///
/// The pointer borrows from the `CString`, so it is only valid while that
/// string is alive.
pub fn prompt_ptr(prompt: Option<&CString>) -> *const c_char {
    prompt.map_or(ptr::null(), |text| text.as_ptr())
}

/// Creates an owned copy of a string that is returned from a
/// <code>pam_get_<var>whatever</var></code> function.
///
/// # Safety
///
/// It's on you to provide a valid string.
pub unsafe fn copy_pam_string(result_ptr: *const c_char) -> Result<Option<String>> {
    let borrowed = match NonNull::new(result_ptr.cast_mut()) {
        Some(data) => Some(
            CStr::from_ptr(data.as_ptr())
                .to_str()
                .map_err(|_| ErrorCode::ConversationError)?,
        ),
        None => return Ok(None),
    };
    Ok(borrowed.map(String::from))
}

/// Allocates a string with the given contents on the C heap.
///
/// This is like [`CString::new`], but:
///
/// - it allocates data on the C heap with [`libc::malloc`].
/// - it doesn't take ownership of the data passed in.
pub fn malloc_str(text: &str) -> Result<NonNull<c_char>> {
    let data = text.as_bytes();
    if data.contains(&0) {
        return Err(ErrorCode::ConversationError);
    }
    // +1 for the null terminator
    let data_alloc: *mut c_char = calloc(data.len() + 1);
    // SAFETY: we just allocated this and we have enough room.
    unsafe {
        libc::memcpy(data_alloc.cast(), data.as_ptr().cast(), data.len());
        Ok(NonNull::new_unchecked(data_alloc))
    }
}

/// Overwrites the contents of a NUL-terminated C string with zero bytes.
///
/// Each byte is cleared with a volatile write, so the wipe is not optimized
/// away. The terminating NUL is left where it is. A null pointer is accepted
/// and ignored.
///
/// # Safety
///
/// `cstr` must be null or point to a valid, writable, NUL-terminated C string.
pub unsafe fn zero_c_string(cstr: *mut c_char) {
    if cstr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees a valid NUL-terminated string. The
    // temporary `&CStr` is dropped before any of the writes below.
    let len = unsafe { CStr::from_ptr(cstr) }.to_bytes().len();
    for offset in 0..len {
        // SAFETY: `offset < len`, so every write lands inside the string body.
        unsafe { ptr::write_volatile(cstr.add(offset), 0) };
    }
}

/// Binary data used in requests and responses.
///
/// The Rust type only describes the 5-byte header (`total_length` and
/// `data_type`). The payload bytes sit directly after it in the same
/// allocation, so the memory extends beyond `size_of::<CBinaryData>()`.
/// Because of that, this must be allocated on the C heap (see
/// [`CBinaryData::alloc`]). The methods below handle it only through
/// `NonNull` pointers.
///
/// A Linux-PAM extension.
#[repr(C)]
pub struct CBinaryData {
    /// The total length of the structure, header included; a u32 in network
    /// byte order (BE).
    total_length: [u8; 4],
    /// A tag of undefined meaning.
    data_type: u8,
    /// Start of the payload: `total_length − 5` bytes follow from here.
    /// Declared as a zero-length array so it takes no space in the header.
    data: [u8; 0],
    /// Zero-sized marker keeping this type `!Send`, `!Sync`, and `!Unpin`.
    _marker: Immovable,
}

impl CBinaryData {
    /// Size of the fixed header (`total_length` + `data_type`) preceding the
    /// payload; tied to the actual `repr(C)` layout.
    const HEADER_LEN: usize = offset_of!(Self, data);

    /// Copies the given data to a new `CBinaryData` on the C heap.
    ///
    /// The memory comes from [`calloc`] and must be released with [`free`]
    /// (optionally after [`CBinaryData::zero_contents`]).
    ///
    /// # Errors
    ///
    /// Returns [`ErrorCode::ConversationError`] if the total size does not fit
    /// in a `u32`, or if the C allocator fails to provide memory.
    pub fn alloc((data, data_type): (&[u8], u8)) -> Result<NonNull<CBinaryData>> {
        let buffer_size = data
            .len()
            .checked_add(Self::HEADER_LEN)
            .and_then(|size| u32::try_from(size).ok())
            .ok_or(ErrorCode::ConversationError)?;
        // calloc may return null on failure; never hand that to `new_unchecked`.
        let dest = NonNull::new(calloc::<u8>(buffer_size as usize))
            .ok_or(ErrorCode::ConversationError)?
            .cast::<Self>();
        // SAFETY: `dest` is a fresh, non-null allocation of `buffer_size` bytes:
        // the header plus `data.len()` payload bytes. The header fields are
        // plain bytes written through the raw pointer (no references are
        // created), and the fresh allocation cannot overlap `data`.
        unsafe {
            let header = dest.as_ptr();
            (*header).total_length = buffer_size.to_be_bytes();
            (*header).data_type = data_type;
            ptr::copy_nonoverlapping(data.as_ptr(), Self::data_ptr(dest), data.len());
        }
        Ok(dest)
    }

    /// Payload length: `total_length` minus the header, clamped at zero.
    fn length(&self) -> usize {
        (u32::from_be_bytes(self.total_length) as usize).saturating_sub(Self::HEADER_LEN)
    }

    /// Returns a pointer to the first payload byte of the structure at `ptr`.
    ///
    /// This only computes an address and never dereferences `ptr`.
    fn data_ptr(ptr: NonNull<Self>) -> *mut u8 {
        // `wrapping_add` imposes no in-bounds requirement, which keeps this a
        // sound safe function; dereferencing the result is the caller's job.
        ptr.as_ptr().cast::<u8>().wrapping_add(Self::HEADER_LEN)
    }

    /// Borrows the payload of the structure at `ptr` as a shared slice.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live, well-formed `CBinaryData` (e.g. from
    /// [`CBinaryData::alloc`]) whose allocation holds `total_length` bytes, and
    /// the payload must not be mutated while the slice is in use.
    unsafe fn data_slice<'a>(ptr: NonNull<Self>) -> &'a [u8] {
        // SAFETY: upheld by the caller; shared access only, so overlapping
        // reads do not create aliasing `&mut` references.
        unsafe { slice::from_raw_parts(Self::data_ptr(ptr), ptr.as_ref().length()) }
    }

    /// Returns the payload and its type tag.
    ///
    /// # Safety
    ///
    /// Same requirements as for reading the payload: `ptr` must point to a
    /// live, well-formed `CBinaryData`, unmodified while the slice is in use.
    pub unsafe fn data<'a>(ptr: NonNull<Self>) -> (&'a [u8], u8) {
        // SAFETY: upheld by the caller.
        unsafe { (Self::data_slice(ptr), ptr.as_ref().data_type) }
    }

    /// Overwrites the payload and then the header with zeroes using volatile
    /// writes. After this the structure reports an empty payload.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live, well-formed, writable `CBinaryData`, and no
    /// references into it may be in use.
    pub unsafe fn zero_contents(ptr: NonNull<Self>) {
        // SAFETY: upheld by the caller; every write stays within the
        // `total_length` bytes of the allocation.
        unsafe {
            let len = ptr.as_ref().length();
            let payload = Self::data_ptr(ptr);
            for offset in 0..len {
                ptr::write_volatile(payload.add(offset), 0);
            }
            // Wipe the header last, since the payload length is read from it.
            ptr::write_volatile(ptr.as_ptr(), mem::zeroed());
        }
    }

    /// Makes an owned [`BinaryData`] copy of the payload and type tag.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a live, well-formed `CBinaryData`.
    // Takes a raw pointer rather than `self` and returns an owned copy, which
    // trips clippy's `as_` naming convention; the name is kept for callers.
    #[allow(clippy::wrong_self_convention)]
    pub unsafe fn as_binary_data(ptr: NonNull<Self>) -> BinaryData {
        // SAFETY: upheld by the caller.
        let (data, data_type) = unsafe { Self::data(ptr) };
        (Vec::from(data), data_type).into()
    }
}

#[cfg(test)]
mod tests {
    use super::{
        c_char, copy_pam_string, free, malloc_str, option_cstr, prompt_ptr, zero_c_string,
        CBinaryData, CString, ErrorCode,
    };

    #[test]
    fn test_strings() {
        let str = malloc_str("hello there").unwrap();
        let str = str.as_ptr();
        malloc_str("hell\0 there").unwrap_err();
        unsafe {
            let copied = copy_pam_string(str).unwrap();
            assert_eq!("hello there", copied.unwrap());
            zero_c_string(str);
            // Put a non-UTF-8 byte after the (now leading) NUL; it must not be read.
            let idx_three = str.add(3).as_mut().unwrap();
            // `c_char` is `i8` on some targets and `u8` on others (e.g. aarch64).
            *idx_three = 0x80u8 as c_char;
            let zeroed = copy_pam_string(str).unwrap().unwrap();
            assert!(zeroed.is_empty());
            free(str);
        }
    }

    #[test]
    fn test_null_string() {
        let copied = unsafe { copy_pam_string(std::ptr::null()) }.unwrap();
        assert!(copied.is_none());
        // Zeroing a null pointer is a no-op rather than a crash.
        unsafe { zero_c_string(std::ptr::null_mut()) };
    }

    #[test]
    fn test_binary_data() {
        let payload = [1u8, 2, 3, 4];
        let ptr = CBinaryData::alloc((&payload, 99)).unwrap();
        unsafe {
            let (data, data_type) = CBinaryData::data(ptr);
            assert_eq!(&payload[..], data);
            assert_eq!(99, data_type);
            CBinaryData::zero_contents(ptr);
            let (data, data_type) = CBinaryData::data(ptr);
            assert!(data.is_empty());
            assert_eq!(0, data_type);
            free(ptr.as_ptr());
        }
    }

    #[test]
    fn test_option_str() {
        let good = option_cstr(Some("whatever")).unwrap();
        assert_eq!("whatever", good.unwrap().to_str().unwrap());
        let no_str = option_cstr(None).unwrap();
        assert!(no_str.is_none());
        let bad_str = option_cstr(Some("what\0ever")).unwrap_err();
        assert_eq!(ErrorCode::ConversationError, bad_str);
    }

    #[test]
    fn test_prompt() {
        let prompt_cstr = CString::new("good").ok();
        let prompt = prompt_ptr(prompt_cstr.as_ref());
        assert!(!prompt.is_null());
        let no_prompt = prompt_ptr(None);
        assert!(no_prompt.is_null());
    }
}