On Fri, May 24, 2019 at 07:09:23PM -0300, Jason Gunthorpe wrote:
> On Fri, May 24, 2019 at 02:46:08PM -0400, Jerome Glisse wrote:
> > > Here is the big 3 CPU ladder diagram that shows how 'valid' does not
> > > work:
> > >
> > >      CPU0                          CPU1                          CPU2
> > >                                                            DEVICE PAGE FAULT
> > >                                                            range = hmm_range_register()
> > >
> > > // Overlaps with range
> > > hmm_invalidate_start()
> > >   range->valid = false
> > >   ops->sync_cpu_device_pagetables()
> > >     take_lock(driver->update);
> > >     // Wipe out page tables in device, enable faulting
> > >     release_lock(driver->update);
> > >                                    // Does not overlap with range
> > >                                    hmm_invalidate_start()
> > >                                    hmm_invalidate_end()
> > >                                      list_for_each
> > >                                         range->valid = true
> >
> > ^
> > No this can not happen because CPU0 still has invalidate_range in progress and
> > thus hmm->notifiers > 0 so the hmm_invalidate_range_end() will not set the
> > range->valid as true.
>
> Oh, Okay, I now see how this all works, thank you
>
> > > And I can make this more complicated (ie overlapping parallel
> > > invalidates, etc) and show any 'bool' valid cannot work.
> >
> > It does work.
>
> Well, I ment the bool alone cannot work, but this is really bool + a
> counter.

I couldn't shake the unease that a bool shouldn't work for this type of
locking, especially since ODP also used a sequence lock as well as the
rwsem...

What about this situation:

       CPU0                                      CPU1
                                           DEVICE PAGE FAULT
                                           range = hmm_range_register()
                                           hmm_range_snapshot(&range);

// Overlaps with range
hmm_invalidate_start()
  range->valid = false
  ops->sync_cpu_device_pagetables()
    take_lock(driver->update);
    // Wipe out page tables in device, enable faulting
    release_lock(driver->update);
hmm_invalidate_end()
  range->valid = true
                                           take_lock(driver->update);
                                           if (!hmm_range_valid(&range))
                                                 goto again
                                           ESTABLISH SPTES
                                           release_lock(driver->update);

The ODP patch appears to follow this pattern, as the dma_map and the
mlx5_ib_update_xlt are in different locking regions.

We should dump the result of hmm_range_snapshot() in this case; certainly
the driver shouldn't be tasked with fixing this.
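ie the driver should bracket the snapshot and the device page table update
in one critical section and simply retry when it was interrupted, with the
core (not the driver) deciding when the pfns array has to be thrown away.
Roughly this (pseudo-code only, same shape as the hmm.rst hunk in the patch
below, driver locking as in hmm.rst):

again:
	if (!hmm_range_start_update(&range, TIMEOUT_IN_MSEC))
		goto err;	/* timed out waiting for invalidations to drain */

	down_read(&mm->mmap_sem);
	ret = hmm_range_snapshot(&range);
	if (ret) {
		up_read(&mm->mmap_sem);
		if (ret == -EAGAIN)
			goto again;
		goto err;
	}

	take_lock(driver->update);
	if (!hmm_range_end_update(&range)) {
		/* Interrupted by a parallel invalidation, the snapshot is
		 * stale and must be discarded */
		release_lock(driver->update);
		up_read(&mm->mmap_sem);
		goto again;
	}
	/* dma_map / ESTABLISH SPTES from the pfns array, still holding
	 * driver->update */
	release_lock(driver->update);
	up_read(&mm->mmap_sem);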
So, something like this is what I'm thinking about:

>From 41b6a6120e30978e7335ada04fb9168db4e5fd29 Mon Sep 17 00:00:00 2001
From: Jason Gunthorpe <jgg@xxxxxxxxxxxx>
Date: Mon, 27 May 2019 16:48:53 -0300
Subject: [PATCH] RFC mm/hmm: Replace the range->valid with a seqcount

Instead of trying to use a single valid to keep track of parallel
invalidations use a traditional seqcount retry lock.

Replace the range->valid with the concept of an 'update critical region'
bounded by hmm_range_start_update() / hmm_range_end_update() which can
fail and need retry if it is ever interrupted by a parallel invalidation.
Updaters must create the critical section and can only finish their
update while holding the device_lock.

Continue to take a very loose approach to track invalidation, now with a
single global seqcount for all ranges. This is done to minimize the
overhead in the mmu notifier, and expects there will only be a small
number of ranges active at once. It could be converted to a seqcount per
range if necessary.

Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxxxx>
---
 Documentation/vm/hmm.rst | 22 +++++--------
 include/linux/hmm.h      | 60 ++++++++++++++++++++++++---------
 mm/hmm.c                 | 71 ++++++++++++++++++++++------------------
 3 files changed, 93 insertions(+), 60 deletions(-)

diff --git a/Documentation/vm/hmm.rst b/Documentation/vm/hmm.rst
index 7c1e929931a07f..7e827058964579 100644
--- a/Documentation/vm/hmm.rst
+++ b/Documentation/vm/hmm.rst
@@ -229,32 +229,27 @@ The usage pattern is::
        * will use the return value of hmm_range_snapshot() below under the
        * mmap_sem to ascertain the validity of the range.
        */
-      hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
- again:
+      if (!hmm_range_start_update(&range, TIMEOUT_IN_MSEC))
+          goto err
+
       down_read(&mm->mmap_sem);
       ret = hmm_range_snapshot(&range);
       if (ret) {
           up_read(&mm->mmap_sem);
-          if (ret == -EAGAIN) {
-            /*
-             * No need to check hmm_range_wait_until_valid() return value
-             * on retry we will get proper error with hmm_range_snapshot()
-             */
-            hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
-            goto again;
-          }
+          if (ret == -EAGAIN)
+              goto again;
           hmm_mirror_unregister(&range);
           return ret;
       }
 
       take_lock(driver->update);
-      if (!hmm_range_valid(&range)) {
+      if (!hmm_range_end_update(&range)) {
           release_lock(driver->update);
           up_read(&mm->mmap_sem);
           goto again;
       }
 
-      // Use pfns array content to update device page table
+      // Use pfns array content to update device page table, must hold driver->update
 
       hmm_mirror_unregister(&range);
       release_lock(driver->update);
@@ -264,7 +259,8 @@ The usage pattern is::
 
 The driver->update lock is the same lock that the driver takes inside its
 sync_cpu_device_pagetables() callback. That lock must be held before calling
-hmm_range_valid() to avoid any race with a concurrent CPU page table update.
+hmm_range_end_update() to avoid any race with a concurrent CPU page table
+update.
 HMM implements all this on top of the mmu_notifier API because we wanted a
 simpler API and also to be able to perform optimizations latter on like doing
diff --git a/include/linux/hmm.h b/include/linux/hmm.h
index 26dfd9377b5094..9096113cfba8de 100644
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -90,7 +90,9 @@
 * @mmu_notifier: mmu notifier to track updates to CPU page table
 * @mirrors_sem: read/write semaphore protecting the mirrors list
 * @wq: wait queue for user waiting on a range invalidation
- * @notifiers: count of active mmu notifiers
+ * @active_invalidates: count of active mmu notifier invalidations
+ * @range_invalidated: seqcount indicating that an active range was
+ *                     maybe invalidated
 */
 struct hmm {
	struct mm_struct	*mm;
@@ -102,7 +104,8 @@ struct hmm {
	struct rw_semaphore	mirrors_sem;
	wait_queue_head_t	wq;
	struct rcu_head		rcu;
-	long			notifiers;
+	unsigned int		active_invalidates;
+	seqcount_t		range_invalidated;
 };
 
 /*
@@ -169,7 +172,7 @@ enum hmm_pfn_value_e {
 * @pfn_flags_mask: allows to mask pfn flags so that only default_flags matter
 * @page_shift: device virtual address shift value (should be >= PAGE_SHIFT)
 * @pfn_shifts: pfn shift value (should be <= PAGE_SHIFT)
- * @valid: pfns array did not change since it has been fill by an HMM function
+ * @update_seq: sequence number for the seqcount lock read side
 */
 struct hmm_range {
	struct hmm		*hmm;
@@ -184,7 +187,7 @@ struct hmm_range {
	uint64_t		pfn_flags_mask;
	uint8_t			page_shift;
	uint8_t			pfn_shift;
-	bool			valid;
+	unsigned int		update_seq;
 };
 
 /*
@@ -208,27 +211,52 @@ static inline unsigned long hmm_range_page_size(const struct hmm_range *range)
 }
 
 /*
- * hmm_range_wait_until_valid() - wait for range to be valid
+ * hmm_range_start_update() - wait for range to be valid
 * @range: range affected by invalidation to wait on
 * @timeout: time out for wait in ms (ie abort wait after that period of time)
 * Return: true if the range is valid, false otherwise.
 */
-static inline bool hmm_range_wait_until_valid(struct hmm_range *range,
-					      unsigned long timeout)
+// FIXME: hmm should handle the timeout for the driver too.
+static inline unsigned int hmm_range_start_update(struct hmm_range *range,
+						  unsigned long timeout)
 {
-	wait_event_timeout(range->hmm->wq, range->valid,
+	wait_event_timeout(range->hmm->wq,
+			   READ_ONCE(range->hmm->active_invalidates) == 0,
			   msecs_to_jiffies(timeout));
-	return READ_ONCE(range->valid);
+
+	// FIXME: wants a non-raw seq helper
+	seqcount_lockdep_reader_access(&range->hmm->range_invalidated);
+	range->update_seq = raw_seqcount_begin(&range->hmm->range_invalidated);
+	return !read_seqcount_retry(&range->hmm->range_invalidated,
+				    range->update_seq);
 }
 
 /*
- * hmm_range_valid() - test if a range is valid or not
+ * hmm_range_needs_retry() - test if a range that has begun update
+ *                           via hmm_range_start_update() needs to be retried.
 * @range: range
- * Return: true if the range is valid, false otherwise.
+ * Return: true if the update needs to be restarted from hmm_range_start_update(),
+ *         false otherwise.
+ */
+static inline bool hmm_range_needs_retry(struct hmm_range *range)
+{
+	return read_seqcount_retry(&range->hmm->range_invalidated,
+				   range->update_seq);
+}
+
+/*
+ * hmm_range_end_update() - finish an update of a range
+ * @range: range
+ *
+ * This must only be called while holding the device lock as described in
+ * hmm.rst, and must be called before making any of the update visible.
+ *
+ * Return: true if the update was not interrupted by an invalidation of the
+ * covered virtual memory range, false if the update needs to be retried.
 */
-static inline bool hmm_range_valid(struct hmm_range *range)
+static inline bool hmm_range_end_update(struct hmm_range *range)
 {
-	return range->valid;
+	return !hmm_range_needs_retry(range);
 }
 
 /*
@@ -511,7 +539,7 @@ static inline int hmm_mirror_range_register(struct hmm_range *range,
 /* This is a temporary helper to avoid merge conflict between trees. */
 static inline bool hmm_vma_range_done(struct hmm_range *range)
 {
-	bool ret = hmm_range_valid(range);
+	bool ret = !hmm_range_needs_retry(range);
 
	hmm_range_unregister(range);
	return ret;
@@ -537,7 +565,7 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block)
	if (ret)
		return (int)ret;
 
-	if (!hmm_range_wait_until_valid(range, HMM_RANGE_DEFAULT_TIMEOUT)) {
+	if (!hmm_range_start_update(range, HMM_RANGE_DEFAULT_TIMEOUT)) {
		hmm_range_unregister(range);
		/*
		 * The mmap_sem was taken by driver we release it here and
@@ -549,6 +577,8 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block)
	}
 
	ret = hmm_range_fault(range, block);
+	if (!hmm_range_end_update(range))
+		ret = -EAGAIN;
	if (ret <= 0) {
		hmm_range_unregister(range);
		if (ret == -EBUSY || !ret) {
diff --git a/mm/hmm.c b/mm/hmm.c
index 8396a65710e304..09725774ff6112 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -79,7 +79,8 @@ static struct hmm *hmm_get_or_create(struct mm_struct *mm)
	INIT_LIST_HEAD(&hmm->ranges);
	mutex_init(&hmm->lock);
	kref_init(&hmm->kref);
-	hmm->notifiers = 0;
+	hmm->active_invalidates = 0;
+	seqcount_init(&hmm->range_invalidated);
	hmm->mm = mm;
	mmgrab(mm);
 
@@ -155,13 +156,22 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
	hmm_put(hmm);
 }
 
+static bool any_range_overlaps(struct hmm *hmm, unsigned long start, unsigned long end)
+{
+	struct hmm_range *range;
+
+	list_for_each_entry(range, &hmm->ranges, list)
+		// FIXME: check me
+		if (range->start <= end && range->end < start)
+			return true;
+	return false;
+}
 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
			const struct mmu_notifier_range *nrange)
 {
	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
	struct hmm_mirror *mirror;
	struct hmm_update update;
-	struct hmm_range *range;
	int ret = 0;
 
	/* hmm is in progress to free */
@@ -179,13 +189,22 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
		ret = -EAGAIN;
		goto out;
	}
-	hmm->notifiers++;
-	list_for_each_entry(range, &hmm->ranges, list) {
-		if (update.end < range->start || update.start >= range->end)
-			continue;
-
-		range->valid = false;
-	}
+	/*
+	 * The seqcount is used as a fast but inaccurate indication that
+	 * another CPU working with a range needs to retry due to invalidation
+	 * of the same virtual address space covered by the range by this
+	 * CPU.
+	 *
+	 * It is based on the assumption that the ranges will be short lived,
+	 * so there is no need to be aggressively accurate in the mmu notifier
+	 * fast path. Any notifier intersection will cause all registered
+	 * ranges to retry.
+	 */
+	hmm->active_invalidates++;
+	// FIXME: needs a seqcount helper
+	if (!(hmm->range_invalidated.sequence & 1) &&
+	    any_range_overlaps(hmm, update.start, update.end))
+		write_seqcount_begin(&hmm->range_invalidated);
	mutex_unlock(&hmm->lock);
 
	if (mmu_notifier_range_blockable(nrange))
@@ -218,15 +237,11 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
		return;
 
	mutex_lock(&hmm->lock);
-	hmm->notifiers--;
-	if (!hmm->notifiers) {
-		struct hmm_range *range;
-
-		list_for_each_entry(range, &hmm->ranges, list) {
-			if (range->valid)
-				continue;
-			range->valid = true;
-		}
+	hmm->active_invalidates--;
+	if (hmm->active_invalidates == 0) {
+		// FIXME: needs a seqcount helper
+		if (hmm->range_invalidated.sequence & 1)
+			write_seqcount_end(&hmm->range_invalidated);
		wake_up_all(&hmm->wq);
	}
	mutex_unlock(&hmm->lock);
@@ -882,7 +897,7 @@ int hmm_range_register(struct hmm_range *range,
 {
	unsigned long mask = ((1UL << page_shift) - 1UL);
 
-	range->valid = false;
+	range->update_seq = 0;
	range->hmm = NULL;
 
	if ((start & mask) || (end & mask))
@@ -908,15 +923,7 @@ int hmm_range_register(struct hmm_range *range,
 
	/* Initialize range to track CPU page table updates. */
	mutex_lock(&range->hmm->lock);
 
-	list_add(&range->list, &range->hmm->ranges);
-
-	/*
-	 * If there are any concurrent notifiers we have to wait for them for
-	 * the range to be valid (see hmm_range_wait_until_valid()).
-	 */
-	if (!range->hmm->notifiers)
-		range->valid = true;
	mutex_unlock(&range->hmm->lock);
 
	return 0;
@@ -947,7 +954,6 @@ void hmm_range_unregister(struct hmm_range *range)
	hmm_put(hmm);
 
	/* The range is now invalid, leave it poisoned. */
-	range->valid = false;
	memset(&range->hmm, POISON_INUSE, sizeof(range->hmm));
 }
 EXPORT_SYMBOL(hmm_range_unregister);
@@ -981,7 +987,7 @@ long hmm_range_snapshot(struct hmm_range *range)
 
	do {
		/* If range is no longer valid force retry. */
-		if (!range->valid)
+		if (hmm_range_needs_retry(range))
			return -EAGAIN;
 
		vma = find_vma(hmm->mm, start);
@@ -1080,7 +1086,7 @@ long hmm_range_fault(struct hmm_range *range, bool block)
 
	do {
		/* If range is no longer valid force retry. */
-		if (!range->valid) {
+		if (hmm_range_needs_retry(range)) {
			up_read(&hmm->mm->mmap_sem);
			return -EAGAIN;
		}
@@ -1134,7 +1140,7 @@ long hmm_range_fault(struct hmm_range *range, bool block)
		start = hmm_vma_walk.last;
 
		/* Keep trying while the range is valid. */
-	} while (ret == -EBUSY && range->valid);
+	} while (ret == -EBUSY && !hmm_range_needs_retry(range));
 
	if (ret) {
		unsigned long i;
@@ -1195,7 +1201,8 @@ long hmm_range_dma_map(struct hmm_range *range,
			continue;
 
		/* Check if range is being invalidated */
-		if (!range->valid) {
+		if (hmm_range_needs_retry(range)) {
+			// ?? EAGAIN??
			ret = -EBUSY;
			goto unmap;
		}
-- 
2.21.0
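One more note on the RFC: the any_range_overlaps() helper above is still
carrying its "FIXME: check me" for a reason - unless I'm misreading it, the
test "range->start <= end && range->end < start" fires for ranges sitting
entirely below the invalidated span and misses the actual overlaps. What is
presumably wanted is just the inverse of the skip test the old notifier loop
used, something like this sketch:

static bool any_range_overlaps(struct hmm *hmm, unsigned long start,
			       unsigned long end)
{
	struct hmm_range *range;

	list_for_each_entry(range, &hmm->ranges, list)
		/* Inverse of the old "update.end < range->start ||
		 * update.start >= range->end" skip test */
		if (!(end < range->start || start >= range->end))
			return true;
	return false;
}

I've left the FIXME in the patch above until that is double checked.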