On Thu, Jan 20, 2022 at 04:55:22PM +0100, Peter Zijlstra wrote: > +/* > + * Pinning a page inhibits rmap based unmap for Anon pages. Doing a load > + * through the user mapping ensures the user mapping exists. > + */ > +#define umcg_pin_and_load(_self, _pagep, _member) \ > +({ \ > + __label__ __out; \ > + int __ret = -EFAULT; \ > + \ > + if (pin_user_pages_fast((unsigned long)(_self), 1, 0, &(_pagep)) != 1) \ > + goto __out; \ > + \ > + if (!PageAnon(_pagep) || \ > + get_user(_member, &(_self)->_member)) { \ > + unpin_user_page(_pagep); \ > + goto __out; \ > + } \ > + __ret = 0; \ > +__out: __ret; \ > +}) Per the thread with David, this wants changing like so. --- --- a/kernel/sched/umcg.c +++ b/kernel/sched/umcg.c @@ -34,25 +34,26 @@ static struct task_struct *umcg_get_task } /* - * Pinning a page inhibits rmap based unmap for Anon pages. Doing a load - * through the user mapping ensures the user mapping exists. + * Pinning a page inhibits rmap based unmap for Anon pages. Doing a store + * through the user mapping ensures the user mapping exists and is writable. 
*/ -#define umcg_pin_and_load(_self, _pagep, _member) \ -({ \ - __label__ __out; \ - int __ret = -EFAULT; \ - \ - if (pin_user_pages_fast((unsigned long)(_self), 1, 0, &(_pagep)) != 1) \ - goto __out; \ - \ - if (!PageAnon(_pagep) || \ - get_user(_member, &(_self)->_member)) { \ - unpin_user_page(_pagep); \ - goto __out; \ - } \ - __ret = 0; \ -__out: __ret; \ -}) +static int umcg_pin_page(struct umcg_task __user *self, struct page **pagep) +{ + int ret = -EFAULT; + + if (pin_user_pages_fast((unsigned long)self, 1, FOLL_WRITE, pagep) != 1) + goto out; + + if (!PageAnon(*pagep) || + put_user(0ULL, &self->__zero[0])) { + unpin_user_page(*pagep); + goto out; + } + + ret = 0; +out: + return ret; +} /** * umcg_pin_pages: pin pages containing struct umcg_task of @@ -72,10 +73,14 @@ static int umcg_pin_pages(void) tsk->umcg_server)) return -EBUSY; - ret = umcg_pin_and_load(self, tsk->umcg_page, server_tid); + ret = umcg_pin_page(self, &tsk->umcg_page); if (ret) goto clear_self; + ret = -EFAULT; + if (get_user(server_tid, &self->server_tid)) + goto unpin_self; + ret = -ESRCH; server = umcg_get_task(server_tid); if (!server) @@ -83,7 +88,7 @@ static int umcg_pin_pages(void) /* must cache due to possible concurrent change */ tsk->umcg_server_task = READ_ONCE(server->umcg_task); - ret = umcg_pin_and_load(tsk->umcg_server_task, tsk->umcg_server_page, server_tid); + ret = umcg_pin_page(tsk->umcg_server_task, &tsk->umcg_server_page); if (ret) goto clear_server; @@ -414,7 +419,7 @@ static int umcg_wait(u64 timo) break; } - ret = umcg_pin_and_load(self, page, state); + ret = umcg_pin_page(self, &page); if (ret) { page = NULL; break;