diff options
Diffstat (limited to 'rust/kernel/uaccess.rs')
| -rw-r--r-- | rust/kernel/uaccess.rs | 167 | 
1 files changed, 153 insertions, 14 deletions
| diff --git a/rust/kernel/uaccess.rs b/rust/kernel/uaccess.rs index 6d70edd8086a..a8fb4764185a 100644 --- a/rust/kernel/uaccess.rs +++ b/rust/kernel/uaccess.rs @@ -8,14 +8,57 @@ use crate::{      alloc::{Allocator, Flags},      bindings,      error::Result, -    ffi::c_void, +    ffi::{c_char, c_void},      prelude::*,      transmute::{AsBytes, FromBytes},  };  use core::mem::{size_of, MaybeUninit}; -/// The type used for userspace addresses. -pub type UserPtr = usize; +/// A pointer into userspace. +/// +/// This is the Rust equivalent to C pointers tagged with `__user`. +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct UserPtr(*mut c_void); + +impl UserPtr { +    /// Create a `UserPtr` from an integer representing the userspace address. +    #[inline] +    pub fn from_addr(addr: usize) -> Self { +        Self(addr as *mut c_void) +    } + +    /// Create a `UserPtr` from a pointer representing the userspace address. +    #[inline] +    pub fn from_ptr(addr: *mut c_void) -> Self { +        Self(addr) +    } + +    /// Cast this userspace pointer to a raw const void pointer. +    /// +    /// It is up to the caller to use the returned pointer correctly. +    #[inline] +    pub fn as_const_ptr(self) -> *const c_void { +        self.0 +    } + +    /// Cast this userspace pointer to a raw mutable void pointer. +    /// +    /// It is up to the caller to use the returned pointer correctly. +    #[inline] +    pub fn as_mut_ptr(self) -> *mut c_void { +        self.0 +    } + +    /// Increment this user pointer by `add` bytes. +    /// +    /// This addition is wrapping, so wrapping around the address space does not result in a panic +    /// even if `CONFIG_RUST_OVERFLOW_CHECKS` is enabled. +    #[inline] +    pub fn wrapping_byte_add(self, add: usize) -> UserPtr { +        UserPtr(self.0.wrapping_byte_add(add)) +    } +}  /// A pointer to an area in userspace memory, which can be either read-only or read-write.  /// @@ -177,7 +220,7 @@ impl UserSliceReader {      pub fn skip(&mut self, num_skip: usize) -> Result {          // Update `self.length` first since that's the fallible part of this operation.          self.length = self.length.checked_sub(num_skip).ok_or(EFAULT)?; -        self.ptr = self.ptr.wrapping_add(num_skip); +        self.ptr = self.ptr.wrapping_byte_add(num_skip);          Ok(())      } @@ -224,11 +267,11 @@ impl UserSliceReader {          }          // SAFETY: `out_ptr` points into a mutable slice of length `len`, so we may write          // that many bytes to it. -        let res = unsafe { bindings::copy_from_user(out_ptr, self.ptr as *const c_void, len) }; +        let res = unsafe { bindings::copy_from_user(out_ptr, self.ptr.as_const_ptr(), len) };          if res != 0 {              return Err(EFAULT);          } -        self.ptr = self.ptr.wrapping_add(len); +        self.ptr = self.ptr.wrapping_byte_add(len);          self.length -= len;          Ok(())      } @@ -240,7 +283,7 @@ impl UserSliceReader {      pub fn read_slice(&mut self, out: &mut [u8]) -> Result {          // SAFETY: The types are compatible and `read_raw` doesn't write uninitialized bytes to          // `out`. -        let out = unsafe { &mut *(out as *mut [u8] as *mut [MaybeUninit<u8>]) }; +        let out = unsafe { &mut *(core::ptr::from_mut(out) as *mut [MaybeUninit<u8>]) };          self.read_raw(out)      } @@ -262,14 +305,14 @@ impl UserSliceReader {          let res = unsafe {              bindings::_copy_from_user(                  out.as_mut_ptr().cast::<c_void>(), -                self.ptr as *const c_void, +                self.ptr.as_const_ptr(),                  len,              )          };          if res != 0 {              return Err(EFAULT);          } -        self.ptr = self.ptr.wrapping_add(len); +        self.ptr = self.ptr.wrapping_byte_add(len);          self.length -= len;          // SAFETY: The read above has initialized all bytes in `out`, and since `T` implements          // `FromBytes`, any bit-pattern is a valid value for this type. @@ -291,6 +334,65 @@ impl UserSliceReader {          unsafe { buf.inc_len(len) };          Ok(())      } + +    /// Read a NUL-terminated string from userspace and return it. +    /// +    /// The string is read into `buf` and a NUL-terminator is added if the end of `buf` is reached. +    /// Since there must be space to add a NUL-terminator, the buffer must not be empty. The +    /// returned `&CStr` points into `buf`. +    /// +    /// Fails with [`EFAULT`] if the read happens on a bad address (some data may have been +    /// copied). +    #[doc(alias = "strncpy_from_user")] +    pub fn strcpy_into_buf<'buf>(self, buf: &'buf mut [u8]) -> Result<&'buf CStr> { +        if buf.is_empty() { +            return Err(EINVAL); +        } + +        // SAFETY: The types are compatible and `strncpy_from_user` doesn't write uninitialized +        // bytes to `buf`. +        let mut dst = unsafe { &mut *(core::ptr::from_mut(buf) as *mut [MaybeUninit<u8>]) }; + +        // We never read more than `self.length` bytes. +        if dst.len() > self.length { +            dst = &mut dst[..self.length]; +        } + +        let mut len = raw_strncpy_from_user(dst, self.ptr)?; +        if len < dst.len() { +            // Add one to include the NUL-terminator. +            len += 1; +        } else if len < buf.len() { +            // This implies that `len == dst.len() < buf.len()`. +            // +            // This means that we could not fill the entire buffer, but we had to stop reading +            // because we hit the `self.length` limit of this `UserSliceReader`. Since we did not +            // fill the buffer, we treat this case as if we tried to read past the `self.length` +            // limit and received a page fault, which is consistent with other `UserSliceReader` +            // methods that also return page faults when you exceed `self.length`. +            return Err(EFAULT); +        } else { +            // This implies that `len == buf.len()`. +            // +            // This means that we filled the buffer exactly. In this case, we add a NUL-terminator +            // and return it. Unlike the `len < dst.len()` branch, don't modify `len` because it +            // already represents the length including the NUL-terminator. +            // +            // SAFETY: Due to the check at the beginning, the buffer is not empty. +            unsafe { *buf.last_mut().unwrap_unchecked() = 0 }; +        } + +        // This method consumes `self`, so it can only be called once, thus we do not need to +        // update `self.length`. This sidesteps concerns such as whether `self.length` should be +        // incremented by `len` or `len-1` in the `len == buf.len()` case. + +        // SAFETY: There are two cases: +        // * If we hit the `len < dst.len()` case, then `raw_strncpy_from_user` guarantees that +        //   this slice contains exactly one NUL byte at the end of the string. +        // * Otherwise, `raw_strncpy_from_user` guarantees that the string contained no NUL bytes, +        //   and we have since added a NUL byte at the end. +        Ok(unsafe { CStr::from_bytes_with_nul_unchecked(&buf[..len]) }) +    }  }  /// A writer for [`UserSlice`]. @@ -327,11 +429,11 @@ impl UserSliceWriter {          }          // SAFETY: `data_ptr` points into an immutable slice of length `len`, so we may read          // that many bytes from it. -        let res = unsafe { bindings::copy_to_user(self.ptr as *mut c_void, data_ptr, len) }; +        let res = unsafe { bindings::copy_to_user(self.ptr.as_mut_ptr(), data_ptr, len) };          if res != 0 {              return Err(EFAULT);          } -        self.ptr = self.ptr.wrapping_add(len); +        self.ptr = self.ptr.wrapping_byte_add(len);          self.length -= len;          Ok(())      } @@ -354,16 +456,53 @@ impl UserSliceWriter {          // is a compile-time constant.          let res = unsafe {              bindings::_copy_to_user( -                self.ptr as *mut c_void, -                (value as *const T).cast::<c_void>(), +                self.ptr.as_mut_ptr(), +                core::ptr::from_ref(value).cast::<c_void>(),                  len,              )          };          if res != 0 {              return Err(EFAULT);          } -        self.ptr = self.ptr.wrapping_add(len); +        self.ptr = self.ptr.wrapping_byte_add(len);          self.length -= len;          Ok(())      }  } + +/// Reads a nul-terminated string into `dst` and returns the length. +/// +/// This reads from userspace until a NUL byte is encountered, or until `dst.len()` bytes have been +/// read. Fails with [`EFAULT`] if a read happens on a bad address (some data may have been +/// copied). When the end of the buffer is encountered, no NUL byte is added, so the string is +/// *not* guaranteed to be NUL-terminated when `Ok(dst.len())` is returned. +/// +/// # Guarantees +/// +/// When this function returns `Ok(len)`, it is guaranteed that the first `len` bytes of `dst` are +/// initialized and non-zero. Furthermore, if `len < dst.len()`, then `dst[len]` is a NUL byte. +#[inline] +fn raw_strncpy_from_user(dst: &mut [MaybeUninit<u8>], src: UserPtr) -> Result<usize> { +    // CAST: Slice lengths are guaranteed to be `<= isize::MAX`. +    let len = dst.len() as isize; + +    // SAFETY: `dst` is valid for writing `dst.len()` bytes. +    let res = unsafe { +        bindings::strncpy_from_user( +            dst.as_mut_ptr().cast::<c_char>(), +            src.as_const_ptr().cast::<c_char>(), +            len, +        ) +    }; + +    if res < 0 { +        return Err(Error::from_errno(res as i32)); +    } + +    #[cfg(CONFIG_RUST_OVERFLOW_CHECKS)] +    assert!(res <= len); + +    // GUARANTEES: `strncpy_from_user` was successful, so `dst` has contents in accordance with the +    // guarantees of this function. +    Ok(res as usize) +} | 
