From: Jason Gunthorpe <jgg@xxxxxxxxxxxx> Ralph observes that hmm_range_register() can only be called by a driver while a mirror is registered. Make this clear in the API by passing in the mirror structure as a parameter. This also simplifies understanding the lifetime model for struct hmm, as the hmm pointer must be valid as part of a registered mirror, so all we need in hmm_range_register() is a simple kref_get(). Suggested-by: Ralph Campbell <rcampbell@xxxxxxxxxx> Signed-off-by: Jason Gunthorpe <jgg@xxxxxxxxxxxx> --- include/linux/hmm.h | 7 ++++--- mm/hmm.c | 14 +++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/include/linux/hmm.h b/include/linux/hmm.h index 8b91c90d3b88cb..87d29e085a69f7 100644 --- a/include/linux/hmm.h +++ b/include/linux/hmm.h @@ -503,7 +503,7 @@ static inline bool hmm_mirror_mm_is_alive(struct hmm_mirror *mirror) * Please see Documentation/vm/hmm.rst for how to use the range API. */ int hmm_range_register(struct hmm_range *range, - struct mm_struct *mm, + struct hmm_mirror *mirror, unsigned long start, unsigned long end, unsigned page_shift); @@ -539,7 +539,8 @@ static inline bool hmm_vma_range_done(struct hmm_range *range) } /* This is a temporary helper to avoid merge conflict between trees. 
*/ -static inline int hmm_vma_fault(struct hmm_range *range, bool block) +static inline int hmm_vma_fault(struct hmm_mirror *mirror, + struct hmm_range *range, bool block) { long ret; @@ -552,7 +553,7 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block) range->default_flags = 0; range->pfn_flags_mask = -1UL; - ret = hmm_range_register(range, range->vma->vm_mm, + ret = hmm_range_register(range, mirror, range->start, range->end, PAGE_SHIFT); if (ret) diff --git a/mm/hmm.c b/mm/hmm.c index 824e7e160d8167..fa1b04fcfc2549 100644 --- a/mm/hmm.c +++ b/mm/hmm.c @@ -927,7 +927,7 @@ static void hmm_pfns_clear(struct hmm_range *range, * Track updates to the CPU page table see include/linux/hmm.h */ int hmm_range_register(struct hmm_range *range, - struct mm_struct *mm, + struct hmm_mirror *mirror, unsigned long start, unsigned long end, unsigned page_shift) @@ -935,7 +935,6 @@ int hmm_range_register(struct hmm_range *range, unsigned long mask = ((1UL << page_shift) - 1UL); range->valid = false; - range->hmm = NULL; if ((start & mask) || (end & mask)) return -EINVAL; @@ -946,15 +945,12 @@ int hmm_range_register(struct hmm_range *range, range->start = start; range->end = end; - range->hmm = hmm_get_or_create(mm); - if (!range->hmm) - return -EFAULT; - /* Check if hmm_mm_destroy() was call. */ - if (range->hmm->mm == NULL || range->hmm->dead) { - hmm_put(range->hmm); + if (mirror->hmm->mm == NULL || mirror->hmm->dead) return -EFAULT; - } + + range->hmm = mirror->hmm; + kref_get(&range->hmm->kref); /* Initialize range to track CPU page table update */ mutex_lock(&range->hmm->lock); -- 2.21.0