On Fri, Jun 17, 2016 at 11:00:17AM +0200, Michal Hocko wrote:
> From: Michal Hocko <mhocko@xxxxxxxx>
> 
> vhost driver relies on copy_from_user/get_user from a kernel thread.
> This makes it impossible to reap the memory of an oom victim which
> shares mm with the vhost kernel thread because it could see a zero
> page unexpectedly and theoretically make an incorrect decision visible
> outside of the killed task context.
> 
> Make sure that each place which can read from userspace is annotated
> properly and uses copy_from_user_mm, __get_user_mm or
> copy_from_iter_mm respectively. Each takes the target mm as an
> argument and performs a pessimistic check to rule out that the
> oom_reaper could have unmapped the particular page. __oom_reap_task
> then just needs to mark the mm as unstable before it unmaps any page.
> 
> This is a preparatory patch without any functional changes because
> the oom reaper doesn't touch mm shared with kthreads yet.
> 
> Signed-off-by: Michal Hocko <mhocko@xxxxxxxx>

Will review. Answer to the question below:

> ---
> Hi Michael,
> we have discussed [1] that vhost_worker pins the mm of a potential
> oom victim for too long, which results in an OOM storm when other
> processes have to be killed. One way to address this issue would
> be to pin mm_count rather than mm_users and revalidate it before
> actually doing the copy (mmget_not_zero). You had concerns about
> more atomic operations in the data path. Another way would be to
> postpone exit_mm_victim to after exit_task_work, but as it turned
> out other tasks might have the device open and pin the mm indirectly
> [2].
> 
> Now I would like to attack the issue from another side, which would
> be more generic. I would like to make mm's which are shared with
> kthreads oom reapable in general. This is currently not allowed
> because we do not want to risk that a kthread would see an already
> unmapped page - aka see a newly allocated zero page. At the same
> time this is really desirable because it helps to guarantee forward
> progress on the OOM.
> 
> It seems that vhost usage would suffer from this problem because
> it reads from userspace to get (status) flags and makes some
> decisions based on the read value. I do not understand the code, so I
> couldn't evaluate whether that would lead to real problems, and I
> conservatively assumed it wouldn't handle that gracefully.

Getting an error from __get_user and friends is handled gracefully.
Getting zero instead of a real value will cause userspace memory
corruption.

> If this is
> incorrect and all the paths can cope with seeing zeros unexpectedly,
> then great: I will drop the patch and move over to the oom specific
> further steps.
> 
> Therefore I am proposing a kthread safe API which allows reading from
> userspace and also makes sure to do a proper exclusion with the oom
> reaper. A race would be reported as EFAULT, which is already handled.
> Performance wise it would add two tests to the copy from user
> paths. Does the following change make sense to you, and would it be
> acceptable? If yes, I will follow up with another patch which will
> allow the oom reaper for mm shared with kthreads.
> 
> Thanks!
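To make the zero-vs-error distinction above concrete, a minimal
hypothetical sketch; worker_read_avail_idx is invented for
illustration and is not actual vhost code, but the avail->idx read
mirrors what vhost_get_vq_desc() does today:

static int worker_read_avail_idx(struct vhost_virtqueue *vq, u16 *idx)
{
	__virtio16 raw;

	/* An -EFAULT here is handled: callers vq_err() and stop. */
	if (__get_user(raw, &vq->avail->idx))
		return -EFAULT;

	/*
	 * But if the oom reaper already unmapped the page, the read
	 * above succeeds and returns 0 from a freshly faulted-in zero
	 * page. The bogus index is then used to walk the ring and
	 * corrupts guest-visible state - the case __get_user_mm closes.
	 */
	*idx = vhost16_to_cpu(vq, raw);
	return 0;
}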
> 
> [1] http://lkml.kernel.org/r/1456765329-14890-1-git-send-email-vdavydov@xxxxxxxxxxxxx
> [2] http://lkml.kernel.org/r/20160301181136-mutt-send-email-mst@xxxxxxxxxx
> 
>  drivers/vhost/scsi.c    |  2 +-
>  drivers/vhost/vhost.c   | 18 +++++++++---------
>  include/linux/sched.h   |  1 +
>  include/linux/uaccess.h | 22 ++++++++++++++++++++++
>  include/linux/uio.h     | 10 ++++++++++
>  mm/oom_kill.c           |  6 ++++++
>  6 files changed, 49 insertions(+), 10 deletions(-)
> 
> diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
> index 0e6fd556c982..2c8dc0b9a21f 100644
> --- a/drivers/vhost/scsi.c
> +++ b/drivers/vhost/scsi.c
> @@ -932,7 +932,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
>  	 */
>  	iov_iter_init(&out_iter, WRITE, vq->iov, out, out_size);
>  
> -	ret = copy_from_iter(req, req_size, &out_iter);
> +	ret = copy_from_iter_mm(vq->dev->mm, req, req_size, &out_iter);
>  	if (unlikely(ret != req_size)) {
>  		vq_err(vq, "Faulted on copy_from_iter\n");
>  		vhost_scsi_send_bad_target(vs, vq, head, out);
> diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
> index 669fef1e2bb6..14959ba43cb4 100644
> --- a/drivers/vhost/vhost.c
> +++ b/drivers/vhost/vhost.c
> @@ -1212,7 +1212,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
>  		r = -EFAULT;
>  		goto err;
>  	}
> -	r = __get_user(last_used_idx, &vq->used->idx);
> +	r = __get_user_mm(vq->dev->mm, last_used_idx, &vq->used->idx);
>  	if (r)
>  		goto err;
>  	vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
> @@ -1328,7 +1328,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
>  			       i, count);
>  			return -EINVAL;
>  		}
> -		if (unlikely(copy_from_iter(&desc, sizeof(desc), &from) !=
> +		if (unlikely(copy_from_iter_mm(vq->dev->mm, &desc, sizeof(desc), &from) !=
>  			     sizeof(desc))) {
>  			vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
>  			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
> @@ -1392,7 +1392,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
>  
>  	/* Check it isn't doing very strange things with descriptor numbers. */
>  	last_avail_idx = vq->last_avail_idx;
> -	if (unlikely(__get_user(avail_idx, &vq->avail->idx))) {
> +	if (unlikely(__get_user_mm(vq->dev->mm, avail_idx, &vq->avail->idx))) {
>  		vq_err(vq, "Failed to access avail idx at %p\n",
>  		       &vq->avail->idx);
>  		return -EFAULT;
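As an aside: since the cover letter pitches this as a generic
kthread-safe API, here is a hypothetical sketch of what a non-vhost
use_mm() user could look like under it (kthread_fetch_u16 is invented
for illustration, it exists nowhere in the tree):

#include <linux/mmu_context.h>
#include <linux/uaccess.h>

/* Borrow a userspace mm and read a value through the _mm variant. */
static int kthread_fetch_u16(struct mm_struct *mm,
			     const u16 __user *uptr, u16 *val)
{
	int err;

	use_mm(mm);
	err = __get_user_mm(mm, *val, uptr);
	unuse_mm(mm);

	/* A reap racing with the copy surfaces as -EFAULT here. */
	return err;
}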
> @@ -1414,7 +1414,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
>  
>  	/* Grab the next descriptor number they're advertising, and increment
>  	 * the index we've seen.
>  	 */
> -	if (unlikely(__get_user(ring_head,
> +	if (unlikely(__get_user_mm(vq->dev->mm, ring_head,
>  		     &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
>  		vq_err(vq, "Failed to read head: idx %d address %p\n",
>  		       last_avail_idx,
> @@ -1450,7 +1450,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
>  			       i, vq->num, head);
>  			return -EINVAL;
>  		}
> -		ret = __copy_from_user(&desc, vq->desc + i, sizeof desc);
> +		ret = __copy_from_user_mm(vq->dev->mm, &desc, vq->desc + i, sizeof desc);
>  		if (unlikely(ret)) {
>  			vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
>  			       i, vq->desc + i);
> @@ -1622,7 +1622,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
>  
>  	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
>  		__virtio16 flags;
> -		if (__get_user(flags, &vq->avail->flags)) {
> +		if (__get_user_mm(dev->mm, flags, &vq->avail->flags)) {
>  			vq_err(vq, "Failed to get flags");
>  			return true;
>  		}
> @@ -1636,7 +1636,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
>  	if (unlikely(!v))
>  		return true;
>  
> -	if (__get_user(event, vhost_used_event(vq))) {
> +	if (__get_user_mm(dev->mm, event, vhost_used_event(vq))) {
>  		vq_err(vq, "Failed to get used event idx");
>  		return true;
>  	}
> @@ -1678,7 +1678,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
>  	__virtio16 avail_idx;
>  	int r;
>  
> -	r = __get_user(avail_idx, &vq->avail->idx);
> +	r = __get_user_mm(dev->mm, avail_idx, &vq->avail->idx);
>  	if (r)
>  		return false;
>  
> @@ -1713,7 +1713,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
>  	/* They could have slipped one in as we were doing that: make
>  	 * sure it's written, then check again. */
>  	smp_mb();
> -	r = __get_user(avail_idx, &vq->avail->idx);
> +	r = __get_user_mm(dev->mm,avail_idx, &vq->avail->idx);

space after , pls

>  	if (r) {
>  		vq_err(vq, "Failed to check avail idx at %p: %d\n",
>  		       &vq->avail->idx, r);
> diff --git a/include/linux/sched.h b/include/linux/sched.h
> index 6d81a1eb974a..2b00ac7faa18 100644
> --- a/include/linux/sched.h
> +++ b/include/linux/sched.h
> @@ -513,6 +513,7 @@ static inline int get_dumpable(struct mm_struct *mm)
>  #define MMF_RECALC_UPROBES	20	/* MMF_HAS_UPROBES can be wrong */
>  #define MMF_OOM_REAPED		21	/* mm has been already reaped */
>  #define MMF_OOM_NOT_REAPABLE	22	/* mm couldn't be reaped */
> +#define MMF_UNSTABLE		23	/* mm is unstable for copy_from_user */
>  
>  #define MMF_INIT_MASK		(MMF_DUMPABLE_MASK | MMF_DUMP_FILTER_MASK)
>  
> diff --git a/include/linux/uaccess.h b/include/linux/uaccess.h
> index 349557825428..b1f314fca3c8 100644
> --- a/include/linux/uaccess.h
> +++ b/include/linux/uaccess.h
> @@ -76,6 +76,28 @@ static inline unsigned long __copy_from_user_nocache(void *to,
>  #endif		/* ARCH_HAS_NOCACHE_UACCESS */
>  
>  /*
> + * A safe variant of __get_user for use_mm() users to have a
> + * guarantee that the address space wasn't reaped in the background
> + */
> +#define __get_user_mm(mm, x, ptr)				\
> +({								\
> +	int ___gu_err = __get_user(x, ptr);			\
> +	if (!___gu_err && test_bit(MMF_UNSTABLE, &mm->flags))	\

test_bit is somewhat expensive. See my old mail

	x86/bitops: implement __test_bit

I dropped it as virtio just switched to simple &/| for features,
but we might need something like this now.
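For reference, the idea behind that old proposal, as I read it, is a
plain non-atomic, non-volatile bit test that the compiler is free to
optimize; the sketch below is an assumption about its shape, not the
patch that was actually posted:

#include <linux/bitops.h>

/*
 * Like test_bit(), but a plain load: no volatile access, so the
 * compiler may merge or hoist it with neighbouring reads.
 */
static __always_inline bool __test_bit(long nr, const unsigned long *addr)
{
	return addr[BIT_WORD(nr)] & BIT_MASK(nr);
}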
> +	___gu_err = -EFAULT;					\
> +	___gu_err;						\
> +})
> +
> +/* similar to __get_user_mm */
> +static inline __must_check long __copy_from_user_mm(struct mm_struct *mm,
> +		void *to, const void __user * from, unsigned long n)
> +{
> +	long ret = __copy_from_user(to, from, n);
> +	if (!ret && test_bit(MMF_UNSTABLE, &mm->flags))
> +		return -EFAULT;
> +	return ret;
> +}
> +
> +/*
>   * probe_kernel_read(): safely attempt to read from a location
>   * @dst: pointer to the buffer that shall take the data
>   * @src: address to read from
> diff --git a/include/linux/uio.h b/include/linux/uio.h
> index 1b5d1cd796e2..4be6b24003d8 100644
> --- a/include/linux/uio.h
> +++ b/include/linux/uio.h
> @@ -9,6 +9,7 @@
>  #ifndef __LINUX_UIO_H
>  #define __LINUX_UIO_H
>  
> +#include <linux/sched.h>
>  #include <linux/kernel.h>
>  #include <uapi/linux/uio.h>
>  
> @@ -84,6 +85,15 @@ size_t copy_page_from_iter(struct page *page, size_t offset, size_t bytes,
>  			   struct iov_iter *i);
>  size_t copy_to_iter(const void *addr, size_t bytes, struct iov_iter *i);
>  size_t copy_from_iter(void *addr, size_t bytes, struct iov_iter *i);
> +
> +static inline size_t copy_from_iter_mm(struct mm_struct *mm, void *addr,
> +		size_t bytes, struct iov_iter *i)
> +{
> +	size_t ret = copy_from_iter(addr, bytes, i);
> +	if (!IS_ERR_VALUE(ret) && test_bit(MMF_UNSTABLE, &mm->flags))
> +		return -EFAULT;
> +	return ret;
> +}
>  size_t copy_from_iter_nocache(void *addr, size_t bytes, struct iov_iter *i);
>  size_t iov_iter_zero(size_t bytes, struct iov_iter *);
>  unsigned long iov_iter_alignment(const struct iov_iter *i);
> diff --git a/mm/oom_kill.c b/mm/oom_kill.c
> index 6303bc7caeda..3fa43e96a59b 100644
> --- a/mm/oom_kill.c
> +++ b/mm/oom_kill.c
> @@ -506,6 +506,12 @@ static bool __oom_reap_task(struct task_struct *tsk)
>  		goto mm_drop;
>  	}
>  
> +	/*
> +	 * Tell all users of get_user_mm/copy_from_user_mm that the content
> +	 * is no longer stable.
> +	 */
> +	set_bit(MMF_UNSTABLE, &mm->flags);
> +

do we need some kind of barrier after this? and if yes - does flag
read need a barrier before it too?

>  	tlb_gather_mmu(&tlb, mm, 0, -1);
>  	for (vma = mm->mmap ; vma; vma = vma->vm_next) {
>  		if (is_vm_hugetlb_page(vma))
> -- 
> 2.8.1
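For the record, one possible shape for those barriers; this is only a
sketch (the _sketch helpers are invented names) under the assumption
that the guarantee we want is "a copy that can observe a reaped,
zero-filled page must also observe MMF_UNSTABLE", not a reviewed fix:

#include <linux/sched.h>
#include <linux/uaccess.h>

/* reader side, e.g. inside __copy_from_user_mm() */
static inline long copy_from_user_mm_sketch(struct mm_struct *mm,
		void *to, const void __user *from, unsigned long n)
{
	long ret = __copy_from_user(to, from, n);

	/* order the copy before the flag test; pairs with the reaper */
	smp_rmb();
	if (!ret && test_bit(MMF_UNSTABLE, &mm->flags))
		return -EFAULT;
	return ret;
}

/* reaper side, before tlb_gather_mmu()/unmap_page_range() */
static inline void mark_mm_unstable_sketch(struct mm_struct *mm)
{
	set_bit(MMF_UNSTABLE, &mm->flags);
	/* make the flag visible before any page is unmapped */
	smp_mb__after_atomic();
}

With that pairing, a reader whose copy could have seen the unmapped
page cannot test the flag before the copy completes, so it either sees
the old data or sees MMF_UNSTABLE and returns -EFAULT.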