Bernard Metzler <bmt@xxxxxxxxxxxxxx> writes:

> To avoid racing with other user memory reservations, immediately
> account full amount of pages to be pinned.
>
> Fixes: 2251334dcac9 ("rdma/siw: application buffer management")
> Reported-by: Jason Gunthorpe <jgg@xxxxxxxxxx>
> Suggested-by: Alistair Popple <apopple@xxxxxxxxxx>
> Signed-off-by: Bernard Metzler <bmt@xxxxxxxxxxxxxx>
> ---
>  drivers/infiniband/sw/siw/siw_mem.c | 7 +++++--
>  1 file changed, 5 insertions(+), 2 deletions(-)
>
> diff --git a/drivers/infiniband/sw/siw/siw_mem.c b/drivers/infiniband/sw/siw/siw_mem.c
> index b2b33dd3b4fa..7afdbe3f2266 100644
> --- a/drivers/infiniband/sw/siw/siw_mem.c
> +++ b/drivers/infiniband/sw/siw/siw_mem.c
> @@ -398,7 +398,7 @@ struct siw_umem *siw_umem_get(u64 start, u64 len, bool writable)
>
>  	mlock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
>
> -	if (num_pages + atomic64_read(&mm_s->pinned_vm) > mlock_limit) {
> +	if (atomic64_add_return(num_pages, &mm_s->pinned_vm) > mlock_limit) {
>  		rv = -ENOMEM;
>  		goto out_sem_up;
>  	}
> @@ -429,7 +429,6 @@ struct siw_umem *siw_umem_get(u64 start, u64 len, bool writable)
>  				goto out_sem_up;
>
>  			umem->num_pages += rv;
> -			atomic64_add(rv, &mm_s->pinned_vm);
>  			first_page_va += rv * PAGE_SIZE;
>  			nents -= rv;
>  			got += rv;
> @@ -442,6 +441,10 @@ struct siw_umem *siw_umem_get(u64 start, u64 len, bool writable)
>  	if (rv > 0)
>  		return umem;
>
> +	/* Adjust accounting for pages not pinned */
> +	if (num_pages)
> +		atomic64_sub(num_pages, &mm_s->pinned_vm);
> +
>  	siw_umem_release(umem, false);

Won't this unaccount some pages twice if we bail out of this loop early:

		while (nents) {
			struct page **plist = &umem->page_chunk[i].plist[got];

			rv = pin_user_pages(first_page_va, nents,
					    foll_flags | FOLL_LONGTERM,
					    plist, NULL);
			if (rv < 0)
				goto out_sem_up;

			umem->num_pages += rv;
			first_page_va += rv * PAGE_SIZE;
			nents -= rv;
			got += rv;
		}
		num_pages -= got;

Because siw_umem_release() will subtract umem->num_pages, but num_pages
won't always have been updated? Looks like you could just update
num_pages in the inner loop and eliminate the `got` variable, right?

>  	return ERR_PTR(rv);
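To make the double-unaccounting concrete, here is a minimal user-space model of the error path. It assumes, as described above, that siw_umem_release() subtracts umem->num_pages from pinned_vm. The names pinned_vm, fake_pin() and CHUNK are hypothetical stand-ins for mm_s->pinned_vm, pin_user_pages() and the per-chunk page count. This is a sketch of the accounting only, not kernel code. The `fixed` flag switches between decrementing num_pages once per completed chunk (as in the patch) and decrementing it inside the inner loop (as suggested).

```c
#include <stdatomic.h>

/*
 * Hypothetical user-space model of the pinned_vm accounting in
 * siw_umem_get() after this patch. Nothing here is kernel API:
 * pinned_vm stands in for mm_s->pinned_vm, fake_pin() for
 * pin_user_pages(), and CHUNK for the per-chunk page count.
 */
#define CHUNK 4

static atomic_long pinned_vm;

/* Pin at most 2 pages per call; fail once 'budget' pages are used up. */
static long fake_pin(long nents, long *budget)
{
	long n = nents < 2 ? nents : 2;

	if (*budget <= 0)
		return -1;
	if (n > *budget)
		n = *budget;
	*budget -= n;
	return n;
}

/*
 * Register num_pages pages and return pinned_vm afterwards.
 * fixed == 0: num_pages is reduced only after a chunk completes (patch as posted).
 * fixed != 0: num_pages is reduced inside the inner loop (suggested fix).
 */
long model(long num_pages, long budget, int fixed)
{
	long umem_num_pages = 0;	/* models umem->num_pages */
	long rv = 0;

	atomic_store(&pinned_vm, 0);
	/* Charge the full amount up front, as the patch does. */
	atomic_fetch_add(&pinned_vm, num_pages);

	while (num_pages) {
		long nents = num_pages < CHUNK ? num_pages : CHUNK;
		long got = 0;

		while (nents) {
			rv = fake_pin(nents, &budget);
			if (rv < 0)
				goto out;
			umem_num_pages += rv;
			nents -= rv;
			got += rv;
			if (fixed)
				num_pages -= rv;
		}
		if (!fixed)
			num_pages -= got;
	}
out:
	if (rv < 0) {
		/* Patch: give back the pages that were never pinned. */
		if (num_pages)
			atomic_fetch_sub(&pinned_vm, num_pages);
		/* siw_umem_release(): give back the pages that were pinned. */
		atomic_fetch_sub(&pinned_vm, umem_num_pages);
	}
	return atomic_load(&pinned_vm);
}
```

For example, registering 8 pages in chunks of 4 where pinning fails after 6 pages, model(8, 6, 0) leaves pinned_vm at -2: the 2 pages pinned in the incomplete second chunk are subtracted once through num_pages and again through umem->num_pages. With the inner-loop update, model(8, 6, 1) brings pinned_vm back to 0.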