view src/libpam/environ.rs @ 103:dfcd96a74ac4 default tip

Write a truly prodigious amount of documentation. Adds a bunch of links to the OpenPAM man pages and the XSSO spec, as well as just a bunch of prose and stuff.
author Paul Fisher <paul@pfish.zone>
date Wed, 25 Jun 2025 00:59:24 -0400
parents 3f11b8d30f63
children
line wrap: on
line source

use crate::constants::{ErrorCode, Result};
use crate::environ::{EnvironMap, EnvironMapMut};
use crate::libpam::memory::CHeapString;
use crate::libpam::{memory, pam_ffi, LibPamHandle};
use std::ffi::{c_char, CStr, CString, OsStr, OsString};
use std::marker::PhantomData;
use std::os::unix::ffi::{OsStrExt, OsStringExt};
use std::ptr;
use std::ptr::NonNull;

/// A read-only view of the environment variables stored in a PAM handle.
///
/// Lookups and iteration are forwarded to the borrowed [`LibPamHandle`].
pub struct LibPamEnviron<'a> {
    /// The handle whose environment is read.
    source: &'a LibPamHandle,
}

/// A read-write view of the environment variables stored in a PAM handle.
///
/// Holds the [`LibPamHandle`] exclusively, so that variables can be
/// inserted and removed as well as read.
pub struct LibPamEnvironMut<'a> {
    /// The handle whose environment is read and modified.
    source: &'a mut LibPamHandle,
}

impl LibPamHandle {
    /// Looks up `key` in the PAM environment and returns a copy of its value.
    ///
    /// Returns `None` if the variable is not set, or if `key` can never name
    /// a variable: it contains a NUL byte (and so cannot be passed to C) or
    /// an `=` (which `pam_putenv` always treats as the name/value separator).
    fn environ_get(&self, key: &OsStr) -> Option<OsString> {
        if key.as_bytes().contains(&b'=') {
            // No variable can have `=` in its name, so don't hand PAM a
            // "name" that it might interpret as something else.
            return None;
        }
        let key = CString::new(key.as_bytes()).ok()?;
        // SAFETY: We are a valid handle and are calling with a good key.
        unsafe {
            copy_env(pam_ffi::pam_getenv(
                (self as *const LibPamHandle).cast_mut(),
                key.as_ptr(),
            ))
        }
    }

    /// Sets `key` to `value`, or removes `key` when `value` is `None`.
    ///
    /// Returns the value the variable had before the call, if any.
    /// Removing a variable that is not set is a no-op returning `Ok(None)`.
    ///
    /// # Errors
    ///
    /// - [`ErrorCode::ConversationError`] if `key` contains `=`, or if `key`
    ///   or `value` contains a NUL byte. Such a pair cannot be expressed as
    ///   the `NAME=value` C string that `pam_putenv` takes.
    /// - Whatever error `pam_putenv` itself reports.
    fn environ_set(&mut self, key: &OsStr, value: Option<&OsStr>) -> Result<Option<OsString>> {
        let old = self.environ_get(key);
        if old.is_none() && value.is_none() {
            // pam_putenv returns an error if we try to remove a non-existent
            // environment variable, so just avoid that entirely.
            return Ok(None);
        }
        if key.as_bytes().contains(&b'=') {
            // "A=B" + "=" + "C" would be read by pam_putenv as setting `A`
            // to "B=C", silently touching a different variable than the one
            // the caller named. Refuse instead, with the same error used
            // for keys that cannot be encoded at all.
            return Err(ErrorCode::ConversationError);
        }
        // Room for the key, the value, the `=` separator, and the NUL
        // terminator that `CString::new` appends.
        let total_len = key.len() + value.map(OsStr::len).unwrap_or_default() + 2;
        let mut result = Vec::with_capacity(total_len);
        result.extend(key.as_bytes());
        if let Some(value) = value {
            // "NAME=value" sets the variable; a bare "NAME" removes it.
            result.push(b'=');
            result.extend(value.as_bytes());
        }
        let put = CString::new(result).map_err(|_| ErrorCode::ConversationError)?;
        // SAFETY: This is a valid handle and a valid environment string.
        ErrorCode::result_from(unsafe { pam_ffi::pam_putenv(self, put.as_ptr()) })?;
        Ok(old)
    }

    /// Returns an iterator over every `(key, value)` pair in the PAM
    /// environment, as reported by `pam_getenvlist`.
    ///
    /// # Errors
    ///
    /// [`ErrorCode::BufferError`] if `pam_getenvlist` returns a null pointer.
    fn environ_iter(&self) -> Result<impl Iterator<Item = (OsString, OsString)>> {
        // SAFETY: This is a valid PAM handle. It will return valid data.
        unsafe {
            NonNull::new(pam_ffi::pam_getenvlist(
                (self as *const LibPamHandle).cast_mut(),
            ))
            .map(|ptr| EnvList::from_ptr(ptr.cast()))
            .ok_or(ErrorCode::BufferError)
        }
    }
}

/// Duplicates the C string at `src` into an owned [`OsString`].
///
/// Returns `None` when `src` is null.
///
/// # Safety
///
/// If `src` is non-null, it must point to a valid NUL-terminated C string
/// that stays alive for the duration of this call.
unsafe fn copy_env(src: *const c_char) -> Option<OsString> {
    if src.is_null() {
        return None;
    }
    // SAFETY: The pointer is non-null, and the caller promises that it
    // came from PAM as a proper C string. We have to trust it.
    let bytes = unsafe { CStr::from_ptr(src) }.to_bytes();
    Some(OsStr::from_bytes(bytes).to_os_string())
}

impl<'a> LibPamEnviron<'a> {
    pub fn new(source: &'a LibPamHandle) -> Self {
        Self { source }
    }
}

impl<'a> LibPamEnvironMut<'a> {
    pub fn new(source: &'a mut LibPamHandle) -> Self {
        Self { source }
    }
}

impl EnvironMap<'_> for LibPamEnviron<'_> {
    fn get(&self, key: impl AsRef<OsStr>) -> Option<OsString> {
        self.source.environ_get(key.as_ref())
    }

    fn iter(&self) -> Result<impl Iterator<Item = (OsString, OsString)>> {
        self.source.environ_iter()
    }
}

impl EnvironMap<'_> for LibPamEnvironMut<'_> {
    fn get(&self, key: impl AsRef<OsStr>) -> Option<OsString> {
        self.source.environ_get(key.as_ref())
    }

    fn iter(&self) -> Result<impl Iterator<Item = (OsString, OsString)>> {
        self.source.environ_iter()
    }
}

impl EnvironMapMut<'_> for LibPamEnvironMut<'_> {
    /// Stores `val` under `key` in the wrapped handle, handing back
    /// whatever value was there before.
    fn insert(
        &mut self,
        key: impl AsRef<OsStr>,
        val: impl AsRef<OsStr>,
    ) -> Result<Option<OsString>> {
        let name: &OsStr = key.as_ref();
        let new_value: &OsStr = val.as_ref();
        self.source.environ_set(name, Some(new_value))
    }

    /// Deletes `key` from the wrapped handle, handing back whatever value
    /// was there before.
    fn remove(&mut self, key: impl AsRef<OsStr>) -> Result<Option<OsString>> {
        let name: &OsStr = key.as_ref();
        self.source.environ_set(name, None)
    }
}

/// An owning iterator over a list of environment entries that is
/// terminated by a `None` entry.
///
/// Each entry is released right after it is yielded; any entries that were
/// never yielded, and the list itself, are released when this is dropped.
struct EnvList<'a> {
    /// Pointer to the start of the environment variable list.
    ///
    /// This can't be a `CHeapBox` because it's not just a single
    /// `Option<EnvVar>`.
    ///
    /// It is kept around so the list can be freed on drop.
    start: NonNull<Option<EnvVar>>,
    /// The environment variable we're about to iterate into.
    current: NonNull<Option<EnvVar>>,
    /// Marker carrying the `'a` lifetime of a borrowed handle;
    /// it stores no data.
    _owner: PhantomData<&'a LibPamHandle>,
}

impl EnvList<'_> {
    unsafe fn from_ptr(ptr: NonNull<*mut c_char>) -> Self {
        Self {
            start: ptr.cast(),
            current: ptr.cast(),
            _owner: Default::default(),
        }
    }
}
impl Iterator for EnvList<'_> {
    type Item = (OsString, OsString);

    /// Yields the next `(key, value)` pair, releasing the underlying
    /// entry once its contents have been copied out.
    fn next(&mut self) -> Option<Self::Item> {
        // SAFETY: We were given a pointer to a valid environment list,
        // and we only ever advance it to the exact end of the list.
        let slot = unsafe { self.current.as_mut() };
        // A `None` slot is the end of the list; stay parked on it.
        let var = slot.as_mut()?;
        let pair = var.as_kv();
        // SAFETY: We know we're still pointing to a valid entry, so
        // stepping one past it is allowed, and since the cursor has moved
        // on, nobody will look at the entry we're destroying again.
        unsafe {
            self.current = self.current.add(1);
            ptr::drop_in_place(var as *mut EnvVar);
        }
        Some(pair)
    }
}

impl Drop for EnvList<'_> {
    /// Releases every entry that was never yielded, then the list itself.
    fn drop(&mut self) {
        // SAFETY: We own self.start, and we know that self.current points to
        // either an item we haven't used, or to the None end.
        unsafe {
            let mut cursor = self.current;
            loop {
                match cursor.as_mut() {
                    Some(var) => {
                        ptr::drop_in_place(var as *mut EnvVar);
                        cursor = cursor.add(1);
                    }
                    None => break,
                }
            }
            self.current = cursor;
            memory::free(self.start.as_ptr());
        }
    }
}

/// A single entry of an environment list, wrapping the `CHeapString`
/// that holds its text.
struct EnvVar(CHeapString);

impl EnvVar {
    /// Splits this entry at its first `=` into owned `(key, value)` strings.
    ///
    /// An entry without any `=` becomes a key with an empty value.
    fn as_kv(&self) -> (OsString, OsString) {
        let raw = self.0.to_bytes();
        let separator = raw.iter().position(|&byte| byte == b'=');
        // Everything before the first `=` (or the whole entry) is the key.
        let key = &raw[..separator.unwrap_or(raw.len())];
        // Everything after the first `=` is the value; later `=`s are kept.
        let value = match separator {
            Some(at) => &raw[at + 1..],
            None => &raw[raw.len()..],
        };
        (
            OsString::from_vec(key.to_vec()),
            OsString::from_vec(value.to_vec()),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for turning a string literal into an `OsString`.
    fn os(text: &str) -> OsString {
        OsString::from(text)
    }

    /// Entries are split at the first `=` only; a missing `=` means
    /// an empty value.
    #[test]
    fn test_split_kv() {
        let cases = [
            ("THIS=that", "THIS", "that"),
            ("THESE=those, no one=knows", "THESE", "those, no one=knows"),
            ("HERE=", "HERE", ""),
            ("SOME", "SOME", ""),
            ("", "", ""),
        ];
        for (input, key, value) in cases {
            let var = EnvVar(CHeapString::new(input).unwrap());
            assert_eq!(var.as_kv(), (os(key), os(value)));
        }
    }

    /// Builds an `EnvList` holding the given entries followed by
    /// a terminating `None`.
    fn env_list(strings: &[&'static str]) -> EnvList<'static> {
        let ptrs: NonNull<Option<CHeapString>> = memory::calloc(strings.len() + 1).unwrap();
        let entries = strings
            .iter()
            .map(|&text| Some(CHeapString::new(text).unwrap()))
            .chain(std::iter::once(None));
        unsafe {
            for (idx, entry) in entries.enumerate() {
                ptr::write(ptrs.add(idx).as_ptr(), entry);
            }
            EnvList::from_ptr(ptrs.cast())
        }
    }

    /// Running the iterator to the end yields every entry, in order.
    #[test]
    fn test_iter() {
        let expected = vec![
            (os("ONE"), os("two")),
            (os("BIRDS"), os("birds=birds")),
            (os("me"), os("")),
            (os("you"), os("")),
        ];
        let envs = env_list(&["ONE=two", "BIRDS=birds=birds", "me", "you="]);
        let result: Vec<_> = envs.collect();
        assert_eq!(expected, result);
    }

    /// Dropping a partially-consumed iterator must clean up the entries
    /// that were never yielded.
    #[test]
    fn test_iter_partial() {
        let mut envs = env_list(&[
            "iterating=this",
            "also=here",
            "but not=this one",
            "or even=the last",
        ]);

        let first = envs.next();
        assert_eq!(Some((os("iterating"), os("this"))), first);
        let second = envs.next();
        assert_eq!(Some((os("also"), os("here"))), second);
        // let envs drop
    }
}