diff src/module.rs @ 51:9d1160b02d2c

Safety and doc fixes: - Don't panic when given a string with a null character; instead return `PAM_CONV_ERR`. - Improve pattern matching and use ?s where appropriate. - Format etc.
author Paul Fisher <paul@pfish.zone>
date Sat, 03 May 2025 18:41:25 -0400
parents a921b72743e4
children
line wrap: on
line diff
--- a/src/module.rs	Wed Apr 16 16:55:59 2025 -0400
+++ b/src/module.rs	Sat May 03 18:41:25 2025 -0400
@@ -1,11 +1,10 @@
 //! Functions for use in pam modules.
 
+use crate::constants::{PamFlag, PamResultCode};
+use crate::items::{Item, ItemType};
 use libc::c_char;
 use std::ffi::{CStr, CString};
 
-use crate::constants::{PamFlag, PamResultCode};
-use crate::items::ItemType;
-
 /// Opaque type, used as a pointer when making pam API calls.
 ///
 /// A module is invoked via an external function such as `pam_sm_authenticate`.
@@ -85,21 +84,23 @@
     ///
     /// The data stored under the provided key must be of type `T` otherwise the
     /// behaviour of this function is undefined.
-    pub unsafe fn get_data<T>(&self, key: &str) -> PamResult<&T> {
-        let c_key = CString::new(key).unwrap();
+    ///
+    /// The data, if present, is owned by the current PAM conversation.
+    pub unsafe fn get_data<T>(&self, key: &str) -> PamResult<Option<&T>> {
+        let c_key = CString::new(key).map_err(|_| PamResultCode::PAM_CONV_ERR)?;
         let mut ptr: *const libc::c_void = std::ptr::null();
-        let res = pam_get_data(self, c_key.as_ptr(), &mut ptr);
-        if PamResultCode::PAM_SUCCESS == res && !ptr.is_null() {
-            let typed_ptr = ptr.cast::<T>();
-            let data: &T = &*typed_ptr;
-            Ok(data)
-        } else {
-            Err(res)
+        to_result(pam_get_data(self, c_key.as_ptr(), &mut ptr))?;
+        match ptr.is_null() {
+            true => Ok(None),
+            false => {
+                let typed_ptr = ptr.cast::<T>();
+                Ok(Some(&*typed_ptr))
+            }
         }
     }
 
-    /// Stores a value that can be retrieved later with `get_data`.  The value lives
-    /// as long as the current pam cycle.
+    /// Stores a value that can be retrieved later with `get_data`.
+    /// The conversation takes ownership of the data.
     ///
     /// See the [`pam_set_data` manual page](
     /// https://www.man7.org/linux/man-pages/man3/pam_set_data.3.html).
@@ -107,8 +108,8 @@
     /// # Errors
     ///
     /// Returns an error if the underlying PAM function call fails.
-    pub fn set_data<T>(&self, key: &str, data: Box<T>) -> PamResult<()> {
-        let c_key = CString::new(key).unwrap();
+    pub fn set_data<T>(&mut self, key: &str, data: Box<T>) -> PamResult<()> {
+        let c_key = CString::new(key).map_err(|_| PamResultCode::PAM_CONV_ERR)?;
         let res = unsafe {
             pam_set_data(
                 self,
@@ -120,8 +121,11 @@
         to_result(res)
     }
 
-    /// Retrieves a value that has been set, possibly by the pam client.  This is
-    /// particularly useful for getting a `PamConv` reference.
+    /// Retrieves a value that has been set, possibly by the pam client.
+    /// This is particularly useful for getting a `PamConv` reference.
+    ///
+    /// These items are *references to PAM memory*
+    /// which are *owned by the conversation*.
     ///
     /// See the [`pam_get_item` manual page](
     /// https://www.man7.org/linux/man-pages/man3/pam_get_item.3.html).
@@ -131,26 +135,19 @@
     /// Returns an error if the underlying PAM function call fails.
     pub fn get_item<T: crate::items::Item>(&self) -> PamResult<Option<T>> {
         let mut ptr: *const libc::c_void = std::ptr::null();
-        let (res, item) = unsafe {
+        let out = unsafe {
             let r = pam_get_item(self, T::type_id(), &mut ptr);
+            to_result(r)?;
             let typed_ptr = ptr.cast::<T::Raw>();
-            let t = if typed_ptr.is_null() {
-                None
-            } else {
-                Some(T::from_raw(typed_ptr))
-            };
-            (r, t)
+            match typed_ptr.is_null() {
+                true => None,
+                false => Some(T::from_raw(typed_ptr)),
+            }
         };
-        match res {
-            PamResultCode::PAM_SUCCESS => Ok(item),
-            other => Err(other),
-        }
+        Ok(out)
     }
 
-    /// Sets a value in the pam context. The value can be retrieved using
-    /// `get_item`.
-    ///
-    /// Note that all items are strings, except `PAM_CONV` and `PAM_FAIL_DELAY`.
+    /// Sets an item in the pam context. It can be retrieved using `get_item`.
     ///
     /// See the [`pam_set_item` manual page](
     /// https://www.man7.org/linux/man-pages/man3/pam_set_item.3.html).
@@ -158,11 +155,7 @@
     /// # Errors
     ///
     /// Returns an error if the underlying PAM function call fails.
-    ///
-    /// # Panics
-    ///
-    /// Panics if the provided item key contains a nul byte.
-    pub fn set_item_str<T: crate::items::Item>(&mut self, item: T) -> PamResult<()> {
+    pub fn set_item<T: Item>(&mut self, item: T) -> PamResult<()> {
         let res =
             unsafe { pam_set_item(self, T::type_id(), item.into_raw().cast::<libc::c_void>()) };
         to_result(res)
@@ -178,21 +171,10 @@
     /// # Errors
     ///
     /// Returns an error if the underlying PAM function call fails.
-    ///
-    /// # Panics
-    ///
-    /// Panics if the provided prompt string contains a nul byte.
     pub fn get_user(&self, prompt: Option<&str>) -> PamResult<String> {
-        let prompt_string;
-        let c_prompt = match prompt {
-            Some(p) => {
-                prompt_string = CString::new(p).unwrap();
-                prompt_string.as_ptr()
-            }
-            None => std::ptr::null(),
-        };
+        let prompt = option_cstr(prompt)?;
         let output: *mut c_char = std::ptr::null_mut();
-        let res = unsafe { pam_get_user(self, &output, c_prompt) };
+        let res = unsafe { pam_get_user(self, &output, prompt_ptr(prompt.as_ref())) };
         match res {
             PamResultCode::PAM_SUCCESS => copy_pam_string(output),
             otherwise => Err(otherwise),
@@ -209,25 +191,35 @@
     /// # Errors
     ///
     /// Returns an error if the underlying PAM function call fails.
-    ///
-    /// # Panics
-    ///
-    /// Panics if the provided prompt string contains a nul byte.
     pub fn get_authtok(&self, prompt: Option<&str>) -> PamResult<String> {
-        let prompt_string;
-        let c_prompt = match prompt {
-            Some(p) => {
-                prompt_string = CString::new(p).unwrap();
-                prompt_string.as_ptr()
-            }
-            None => std::ptr::null(),
+        let prompt = option_cstr(prompt)?;
+        let output: *mut c_char = std::ptr::null_mut();
+        let res = unsafe {
+            pam_get_authtok(
+                self,
+                ItemType::AuthTok,
+                &output,
+                prompt_ptr(prompt.as_ref()),
+            )
         };
-        let output: *mut c_char = std::ptr::null_mut();
-        let res = unsafe { pam_get_authtok(self, ItemType::AuthTok, &output, c_prompt) };
-        match res {
-            PamResultCode::PAM_SUCCESS => copy_pam_string(output),
-            otherwise => Err(otherwise),
-        }
+        to_result(res)?;
+        copy_pam_string(output)
+    }
+}
+
+/// Safely converts a `&str` option to a `CString` option.
+fn option_cstr(prompt: Option<&str>) -> PamResult<Option<CString>> {
+    prompt
+        .map(CString::new)
+        .transpose()
+        .map_err(|_| PamResultCode::PAM_CONV_ERR)
+}
+
+/// The pointer to the prompt CString, or null if absent.
+fn prompt_ptr(prompt: Option<&CString>) -> *const c_char {
+    match prompt {
+        Some(c_str) => c_str.as_ptr(),
+        None => std::ptr::null(),
     }
 }
 
@@ -236,10 +228,13 @@
 fn copy_pam_string(result_ptr: *const c_char) -> PamResult<String> {
     // We really shouldn't get a null pointer back here, but if we do, return nothing.
     if result_ptr.is_null() {
-        return Ok(String::from(""));
+        return Ok(String::new());
     }
-    let bytes = unsafe { CStr::from_ptr(result_ptr).to_bytes() };
-    String::from_utf8(bytes.to_vec()).map_err(|_| PamResultCode::PAM_CONV_ERR)
+    let bytes = unsafe { CStr::from_ptr(result_ptr) };
+    Ok(bytes
+        .to_str()
+        .map_err(|_| PamResultCode::PAM_CONV_ERR)?
+        .into())
 }
 
 /// Convenience to transform a `PamResultCode` into a unit `PamResult`.