On Fri, Nov 22, 2024 at 03:40:33PM +0000, Alice Ryhl wrote:
> Introduce a new type called `CurrentTask` that lets you perform various
> operations that are only safe on the `current` task. Use the new type to
> provide a way to access the current mm without incrementing its
> refcount.
>
> With this change, you can write stuff such as
>
> 	let vma = current!().mm().lock_vma_under_rcu(addr);
>
> without incrementing any refcounts.
>
> Signed-off-by: Alice Ryhl <aliceryhl@xxxxxxxxxx>
> ---
> Reviewers: Does accessing task->mm on a non-current task require rcu
> protection?
>
> Christian: If you submit the PidNamespace abstractions this cycle, I'll
> update this to also apply to PidNamespace.
> ---
>  rust/kernel/mm.rs   | 22 ------------------
>  rust/kernel/task.rs | 64 ++++++++++++++++++++++++++++++++++++++++++-----------
>  2 files changed, 51 insertions(+), 35 deletions(-)
>
> diff --git a/rust/kernel/mm.rs b/rust/kernel/mm.rs
> index 50f4861ae4b9..f7d1079391ef 100644
> --- a/rust/kernel/mm.rs
> +++ b/rust/kernel/mm.rs
> @@ -142,28 +142,6 @@ fn deref(&self) -> &MmWithUser {
>
>  // These methods are safe to call even if `mm_users` is zero.
>  impl Mm {
> -    /// Call `mmgrab` on `current.mm`.
> -    #[inline]
> -    pub fn mmgrab_current() -> Option<ARef<Mm>> {
> -        // SAFETY: It's safe to get the `mm` field from current.
> -        let mm = unsafe {
> -            let current = bindings::get_current();
> -            (*current).mm
> -        };
> -
> -        if mm.is_null() {
> -            return None;
> -        }
> -
> -        // SAFETY: The value of `current->mm` is guaranteed to be null or a valid `mm_struct`. We
> -        // just checked that it's not null. Furthermore, the returned `&Mm` is valid only for the
> -        // duration of this function, and `current->mm` will stay valid for that long.
> -        let mm = unsafe { Mm::from_raw(mm) };
> -
> -        // This increments the refcount using `mmgrab`.
> -        Some(ARef::from(mm))
> -    }
> -
>      /// Returns a raw pointer to the inner `mm_struct`.
>      #[inline]
>      pub fn as_raw(&self) -> *mut bindings::mm_struct {
> diff --git a/rust/kernel/task.rs b/rust/kernel/task.rs
> index 9e59d86da42d..103d235eb844 100644
> --- a/rust/kernel/task.rs
> +++ b/rust/kernel/task.rs
> @@ -94,6 +94,26 @@ unsafe impl Send for Task {}
>  // synchronised by C code (e.g., `signal_pending`).
>  unsafe impl Sync for Task {}
>
> +/// Represents a [`Task`] obtained from the `current` global.
> +///
> +/// This type exists to provide more efficient operations that are only valid on the current task.
> +/// For example, to retrieve the pid-namespace of a task, you must use rcu protection unless it is
> +/// the current task.
> +///
> +/// # Invariants
> +///
> +/// Must be equal to `current` of some thread that is currently running somewhere.
> +pub struct CurrentTask(Task);
> +

I think you need to make `CurrentTask` `!Sync`, right? Otherwise, other
threads can obtain a shared reference to the `CurrentTask` and access its
->mm field. I'm thinking of the case where we have scoped thread/workqueue
support in the future:

	let task = current!();
	Workqueue::scoped(|s| {
	    s.spawn(|| {
	        let mm = task.mm();
	        // do something with the mm
	    });
	});

> +// Make all `Task` methods available on `CurrentTask`.
> +impl Deref for CurrentTask {
> +    type Target = Task;
> +    #[inline]
> +    fn deref(&self) -> &Task {
> +        &self.0
> +    }
> +}
> +
>  /// The type of process identifiers (PIDs).
>  type Pid = bindings::pid_t;
>
> @@ -121,27 +141,25 @@ pub fn current_raw() -> *mut bindings::task_struct {
>      /// # Safety
>      ///
>      /// Callers must ensure that the returned object doesn't outlive the current task/thread.
> -    pub unsafe fn current() -> impl Deref<Target = Task> {
> -        struct TaskRef<'a> {
> -            task: &'a Task,
> -            _not_send: NotThreadSafe,
> +    pub unsafe fn current() -> impl Deref<Target = CurrentTask> {
> +        struct TaskRef {
> +            task: *const CurrentTask,
>          }
>
> -        impl Deref for TaskRef<'_> {
> -            type Target = Task;
> +        impl Deref for TaskRef {
> +            type Target = CurrentTask;
>
>              fn deref(&self) -> &Self::Target {
> -                self.task
> +                // SAFETY: The returned reference borrows from this `TaskRef`, so it cannot outlive
> +                // the `TaskRef`, which the caller of `Task::current()` has promised will not
> +                // outlive the task/thread for which `self.task` is the `current` pointer. Thus, it
> +                // is okay to return a `CurrentTask` reference here.
> +                unsafe { &*self.task }
>              }
>          }
>
> -        let current = Task::current_raw();
>          TaskRef {
> -            // SAFETY: If the current thread is still running, the current task is valid. Given
> -            // that `TaskRef` is not `Send`, we know it cannot be transferred to another thread
> -            // (where it could potentially outlive the caller).
> -            task: unsafe { &*current.cast() },
> -            _not_send: NotThreadSafe,
> +            task: Task::current_raw().cast(),
>          }
>      }
>
> @@ -203,6 +221,26 @@ pub fn wake_up(&self) {
>      }
>  }
>
> +impl CurrentTask {
> +    /// Access the address space of this task.
> +    ///
> +    /// To increment the refcount of the referenced `mm`, you can use `ARef::from`.
> +    #[inline]
> +    pub fn mm(&self) -> Option<&MmWithUser> {

Hmm... similar issue: `MmWithUser` is `Sync`, so a `&MmWithUser` returned
here can likewise be shared with other threads.

> +        let mm = unsafe { (*self.as_ptr()).mm };
> +
> +        if mm.is_null() {
> +            None
> +        } else {
> +            // SAFETY: If `current->mm` is non-null, then it references a valid mm with a non-zero
> +            // value of `mm_users`. The returned `&MmWithUser` borrows from `CurrentTask`, so the
> +            // `&MmWithUser` cannot escape the current task, meaning `mm_users` can't reach zero
> +            // while the reference is still live.
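
For illustration, here is a minimal userspace sketch of the `!Sync` suggestion. It assumes the kernel's `NotThreadSafe` marker is a `PhantomData<*mut ()>`. `Task`, `MmWithUser` and `CurrentTask` here are simplified stand-ins, not the kernel types. `Probe`/`NotSyncFallback` are hypothetical helpers that check at compile time whether a type is `Sync`: an inherent associated const applies only when `T: Sync`, and otherwise the trait default is used. The marker keeps `mm()` usable on the owning thread while making `CurrentTask` `!Sync`. This alone does not address the `MmWithUser` point, because a `&MmWithUser` is still `Send`.

```rust
use std::marker::PhantomData;
use std::ops::Deref;

// Stand-in for the kernel marker, assumed to be `PhantomData<*mut ()>`,
// which is neither `Send` nor `Sync`.
type NotThreadSafe = PhantomData<*mut ()>;

// Simplified stand-in for `MmWithUser`; like the kernel type, it is `Sync`.
pub struct MmWithUser {
    users: u32,
}

// Simplified stand-in for `Task`, which is `Sync` in the kernel.
pub struct Task {
    mm: Option<MmWithUser>,
}

// `CurrentTask` with an extra marker field: the marker makes it `!Sync`
// (and `!Send`), so a `&CurrentTask` can no longer reach another thread.
pub struct CurrentTask(Task, NotThreadSafe);

impl Deref for CurrentTask {
    type Target = Task;
    fn deref(&self) -> &Task {
        &self.0
    }
}

impl CurrentTask {
    fn new(mm: Option<MmWithUser>) -> Self {
        CurrentTask(Task { mm }, PhantomData)
    }

    // The returned reference borrows from `&self`, mirroring the patch.
    pub fn mm(&self) -> Option<&MmWithUser> {
        self.0.mm.as_ref()
    }
}

// Hypothetical compile-time probe: the inherent `IS_SYNC` is only selected
// when `T: Sync`; otherwise resolution falls back to the trait default.
#[allow(dead_code)]
struct Probe<T>(PhantomData<T>);
trait NotSyncFallback {
    const IS_SYNC: bool = false;
}
impl<T> NotSyncFallback for Probe<T> {}
#[allow(dead_code)]
impl<T: Sync> Probe<T> {
    const IS_SYNC: bool = true;
}

fn main() {
    let task = CurrentTask::new(Some(MmWithUser { users: 1 }));
    // Same-thread access still works.
    assert_eq!(task.mm().map(|mm| mm.users), Some(1));

    // `Task` and `MmWithUser` are `Sync`; `CurrentTask` is not.
    assert!(Probe::<Task>::IS_SYNC);
    assert!(Probe::<MmWithUser>::IS_SYNC);
    assert!(!Probe::<CurrentTask>::IS_SYNC);

    // This would now fail to compile, because `&CurrentTask` is not `Send`:
    // std::thread::scope(|s| {
    //     s.spawn(|| task.mm().is_some());
    // });
}
```

With the marker in place, the scoped-spawn pattern above is rejected by the compiler instead of relying on the SAFETY comment's argument that the reference cannot escape the current task.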
Regards,
Boqun

> +            Some(unsafe { MmWithUser::from_raw(mm) })
> +        }
> +    }
> +}
> +
>  // SAFETY: The type invariants guarantee that `Task` is always refcounted.
>  unsafe impl crate::types::AlwaysRefCounted for Task {
>      fn inc_ref(&self) {
>
> -- 
> 2.47.0.371.ga323438b13-goog
>