view src/libpam/memory.rs @ 79:2128123b9406

Format (oops!) and make some fun and/or stupid conversions available.
author Paul Fisher <paul@pfish.zone>
date Sun, 08 Jun 2025 04:21:58 -0400
parents 002adfb98c5c
children 5aa1a010f1e8
line wrap: on
line source

//! Things for dealing with memory.

use crate::Result;
use crate::{BinaryData, ErrorCode};
use std::ffi::{c_char, CStr, CString};
use std::marker::{PhantomData, PhantomPinned};
use std::{ptr, slice};

/// Allocates zero-initialized storage for `count` values of type `T` on the C heap.
///
/// The pointer comes straight from [`libc::calloc`]: it may be null if the
/// allocation fails, and the memory should be released with [`free`].
#[inline]
pub fn calloc<T>(count: usize) -> *mut T {
    let elem_size = std::mem::size_of::<T>();
    // SAFETY: calloc has no preconditions; the worst case is a null return or a leak.
    let raw = unsafe { libc::calloc(count, elem_size) };
    raw.cast::<T>()
}

/// Hands memory obtained from the C heap (e.g. via [`calloc`]) back to [`libc::free`].
///
/// Funnelling every release through this one wrapper makes it easy to hook or
/// log frees while debugging.
///
/// # Safety
///
/// `p` must be null or a pointer previously returned by the C allocator that
/// has not been freed yet. Freeing the same pointer twice is undefined behavior.
#[inline]
pub unsafe fn free<T>(p: *mut T) {
    // SAFETY: the caller upholds the contract documented above.
    unsafe { libc::free(p.cast()) }
}

/// Zero-sized marker that strips [`Send`], [`Sync`], and [`Unpin`] from any type
/// containing it.
///
/// The raw pointer in the phantom type opts out of `Send`/`Sync`, and
/// [`PhantomPinned`] opts out of `Unpin`.
#[derive(Debug)]
#[repr(C)]
pub struct Immovable(pub PhantomData<(*mut u8, PhantomPinned)>);

/// Converts an optional `&str` into an optional, owned [`CString`].
///
/// `None` passes straight through as `Ok(None)`.
///
/// # Errors
///
/// [`ErrorCode::ConversationError`] if the string contains an interior NUL byte.
pub fn option_cstr(prompt: Option<&str>) -> Result<Option<CString>> {
    match prompt {
        None => Ok(None),
        Some(text) => CString::new(text)
            .map(Some)
            .map_err(|_| ErrorCode::ConversationError),
    }
}

/// Borrows the raw pointer of an optional [`CString`], with null standing in for `None`.
///
/// The returned pointer is only valid while the borrowed `CString` is alive.
pub fn prompt_ptr(prompt: Option<&CString>) -> *const c_char {
    prompt.map_or(ptr::null(), |text| text.as_ptr())
}

/// Copies a string returned by a <code>pam_get_<var>whatever</var></code>
/// function into an owned Rust [`String`].
///
/// A null pointer is tolerated and yields an empty string rather than an error.
///
/// # Errors
///
/// [`ErrorCode::ConversationError`] if the string is not valid UTF-8.
///
/// # Safety
///
/// `result_ptr` must be null or point to a valid, NUL-terminated C string.
pub unsafe fn copy_pam_string(result_ptr: *const libc::c_char) -> Result<String> {
    // PAM shouldn't hand back null here; if it does, treat it as "nothing" rather than crash.
    let text = if result_ptr.is_null() {
        ""
    } else {
        // SAFETY: non-null, and the caller guarantees it is a valid C string.
        match unsafe { CStr::from_ptr(result_ptr) }.to_str() {
            Ok(text) => text,
            Err(_) => return Err(ErrorCode::ConversationError),
        }
    };
    Ok(text.to_owned())
}

/// Borrows a C string returned from PAM as an `Option<&str>`.
///
/// A null pointer becomes `Ok(None)`; otherwise the bytes are borrowed in place
/// (no copy is made).
///
/// # Errors
///
/// [`ErrorCode::ConversationError`] if the string is not valid UTF-8.
///
/// # Safety
///
/// `data` must be null or point to a NUL-terminated string that stays valid and
/// unmodified for the caller-chosen lifetime `'a`; that lifetime is not checked.
pub unsafe fn wrap_string<'a>(data: *const libc::c_char) -> Result<Option<&'a str>> {
    if data.is_null() {
        return Ok(None);
    }
    let text = CStr::from_ptr(data)
        .to_str()
        .map_err(|_| ErrorCode::ConversationError)?;
    Ok(Some(text))
}

/// Allocates a string with the given contents on the C heap.
///
/// This is like [`CString::new`], but:
///
/// - it allocates data on the C heap with [`libc::malloc`].
/// - it doesn't take ownership of the data passed in.
pub fn malloc_str(text: &str) -> Result<*mut c_char> {
    let data = text.as_bytes();
    if data.contains(&0) {
        return Err(ErrorCode::ConversationError);
    }
    // +1 for the null terminator
    let data_alloc: *mut c_char = calloc(data.len() + 1);
    // SAFETY: we just allocated this and we have enough room.
    unsafe {
        libc::memcpy(data_alloc.cast(), data.as_ptr().cast(), data.len());
    }
    Ok(data_alloc)
}

/// Writes zeroes over the contents of a C string.
///
/// This won't overwrite a null pointer.
///
/// # Safety
///
/// It's up to you to provide a valid C string.
pub unsafe fn zero_c_string(cstr: *mut c_char) {
    if !cstr.is_null() {
        libc::memset(cstr.cast(), 0, libc::strlen(cstr.cast()));
    }
}

/// Binary data used in requests and responses.
///
/// The struct itself is only a 5-byte header; the payload bytes follow it
/// directly in the same allocation. A `CBinaryData` therefore only makes sense
/// behind a pointer into a larger buffer on the C heap, such as one produced by
/// [`CBinaryData::alloc`].
///
/// A Linux-PAM extension.
#[repr(C)]
pub struct CBinaryData {
    /// The total length of the structure, header included; a u32 in network
    /// byte order (BE).
    total_length: [u8; 4],
    /// A tag of undefined meaning.
    data_type: u8,
    /// Start of the payload, stored inline directly after the header.
    /// It holds [`length`](Self::length) bytes (`total_length` − 5).
    data: [u8; 0],
    _marker: Immovable,
}

impl CBinaryData {
    /// Copies the given data into a new `CBinaryData` on the C heap.
    ///
    /// The returned buffer is `data.len() + 5` bytes long and comes from
    /// [`calloc`]; release it with [`free`].
    ///
    /// # Errors
    ///
    /// [`ErrorCode::ConversationError`] if the payload is too large for the
    /// 32-bit length field, or if the allocation fails.
    pub fn alloc((data, data_type): (&[u8], u8)) -> Result<*mut CBinaryData> {
        // The 5 header bytes count toward the stored length.
        let buffer_size = data
            .len()
            .checked_add(5)
            .and_then(|size| u32::try_from(size).ok())
            .ok_or(ErrorCode::ConversationError)?;
        let dest_buffer: *mut CBinaryData = calloc::<u8>(buffer_size as usize).cast();
        if dest_buffer.is_null() {
            // Writing through a null pointer would be undefined behavior.
            // NOTE(review): a dedicated buffer/memory error would fit better if
            // `ErrorCode` has one; ConversationError is the only variant used here.
            return Err(ErrorCode::ConversationError);
        }
        // SAFETY: dest_buffer is non-null and spans buffer_size bytes (header +
        // payload). All writes go through raw-pointer places derived from the
        // allocation itself, so the payload pointer keeps provenance over the
        // whole buffer instead of a zero-length `&mut [u8; 0]`.
        unsafe {
            (*dest_buffer).total_length = buffer_size.to_be_bytes();
            (*dest_buffer).data_type = data_type;
            let payload = ptr::addr_of_mut!((*dest_buffer).data).cast::<u8>();
            ptr::copy_nonoverlapping(data.as_ptr(), payload, data.len());
        }
        Ok(dest_buffer)
    }

    /// Number of payload bytes: `total_length` minus the 5-byte header.
    ///
    /// A stored length below 5 is treated as an empty payload.
    fn length(&self) -> usize {
        u32::from_be_bytes(self.total_length).saturating_sub(5) as usize
    }

    /// Overwrites the payload, type tag, and length of this data with zeroes.
    ///
    /// This does **not** free the memory; the caller still owns the allocation
    /// and must release it separately. The writes are volatile so they are not
    /// optimized away when the buffer is freed right afterwards.
    ///
    /// # Safety
    ///
    /// `self` must be the header of a buffer that really extends
    /// [`length`](Self::length) payload bytes past the header, as produced by
    /// [`alloc`](Self::alloc).
    pub unsafe fn zero_contents(&mut self) {
        let payload = self.data.as_mut_ptr();
        for offset in 0..self.length() {
            ptr::write_volatile(payload.add(offset), 0);
        }
        ptr::write_volatile(&mut self.data_type, 0);
        ptr::write_volatile(&mut self.total_length, [0; 4]);
    }
}

impl<'a> From<&'a CBinaryData> for (&'a [u8], u8) {
    /// Borrows the payload bytes and the type tag of a `CBinaryData`.
    fn from(value: &'a CBinaryData) -> Self {
        let payload_len = value.length();
        // SAFETY: relies on `value` being the header of a C-heap buffer that
        // holds `payload_len` payload bytes right after the header.
        let payload = unsafe { slice::from_raw_parts(value.data.as_ptr(), payload_len) };
        (payload, value.data_type)
    }
}

impl From<&'_ CBinaryData> for BinaryData {
    /// Converts borrowed C binary data by way of its `(payload, type)` pair.
    fn from(value: &'_ CBinaryData) -> Self {
        // Naming the intermediate type is all inference needs to pick the
        // `(&[u8], u8)` conversion before handing the pair to `BinaryData`.
        let pair: (&[u8], u8) = value.into();
        pair.into()
    }
}

impl From<Option<&'_ CBinaryData>> for BinaryData {
    /// Converts optional C binary data, falling back to `BinaryData::default()`
    /// when there is none.
    fn from(value: Option<&CBinaryData>) -> Self {
        match value {
            Some(data) => Self::from(data),
            None => Self::default(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{
        c_char, copy_pam_string, free, malloc_str, option_cstr, prompt_ptr, wrap_string,
        zero_c_string, CBinaryData, CString, ErrorCode,
    };

    #[test]
    fn test_strings() {
        let str = malloc_str("hello there").unwrap();
        malloc_str("hell\0 there").unwrap_err();
        unsafe {
            let copied = copy_pam_string(str).unwrap();
            assert_eq!("hello there", copied);
            zero_c_string(str);
            let idx_three = str.add(3).as_mut().unwrap();
            // Plant a non-UTF-8 byte past the zeroed start: if zeroing had failed,
            // copy_pam_string would now return an error instead of "".
            // `as c_char` (not `as i8`) keeps this compiling where `c_char` is `u8`.
            *idx_three = 0x80u8 as c_char;
            let zeroed = copy_pam_string(str).unwrap();
            assert!(zeroed.is_empty());
            free(str);
        }
    }

    #[test]
    fn test_option_str() {
        let good = option_cstr(Some("whatever")).unwrap();
        assert_eq!("whatever", good.unwrap().to_str().unwrap());
        let no_str = option_cstr(None).unwrap();
        assert!(no_str.is_none());
        let bad_str = option_cstr(Some("what\0ever")).unwrap_err();
        assert_eq!(ErrorCode::ConversationError, bad_str);
    }

    #[test]
    fn test_prompt() {
        let prompt_cstr = CString::new("good").ok();
        let prompt = prompt_ptr(prompt_cstr.as_ref());
        assert!(!prompt.is_null());
        let no_prompt = prompt_ptr(None);
        assert!(no_prompt.is_null());
    }

    #[test]
    fn test_wrap_string() {
        let owned = CString::new("hi").unwrap();
        unsafe {
            assert_eq!(Some("hi"), wrap_string(owned.as_ptr()).unwrap());
            assert_eq!(None, wrap_string(std::ptr::null()).unwrap());
        }
        let bad = CString::new(vec![0xffu8]).unwrap();
        let err = unsafe { wrap_string(bad.as_ptr()) }.unwrap_err();
        assert_eq!(ErrorCode::ConversationError, err);
    }

    #[test]
    fn test_binary_data() {
        let raw = CBinaryData::alloc((&[1u8, 2, 3][..], 99)).unwrap();
        unsafe {
            let (data, data_type) = <(&[u8], u8)>::from(&*raw);
            assert_eq!((&[1u8, 2, 3][..], 99), (data, data_type));
            (*raw).zero_contents();
            let (data, data_type) = <(&[u8], u8)>::from(&*raw);
            assert!(data.is_empty());
            assert_eq!(0, data_type);
            free(raw);
        }
    }
}