Currently check_vma() requires VM_DONTCOPY on a device-dax mapping whenever the dax region is not backed by devmap (i.e. PFN_MAP is not set). VM_DONTCOPY is set via madvise(MADV_DONTFORK), which can only be applied to an address range that mmap() has already returned. Because check_vma() is also called from the device-dax mmap handler, requiring VM_DONTCOPY there makes it impossible for a process to mmap the device at all. Stop enforcing MADV_DONTFORK at mmap() time and instead enforce it when the mapping is actually used, i.e. on fault: split the VMA checks into check_vma_mmap(), which dax_mmap() now calls, and check_vma(), which additionally performs the MADV_DONTFORK check. The requirement itself does not change: userspace is still expected to apply madvise(MADV_DONTFORK) to the mapping after mmap() and before accessing it. Signed-off-by: Joao Martins <joao.m.martins@xxxxxxxxxx> --- drivers/dax/device.c | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/drivers/dax/device.c b/drivers/dax/device.c index 1af823b2fe6b..c6a7f5e12c54 100644 --- a/drivers/dax/device.c +++ b/drivers/dax/device.c @@ -14,7 +14,7 @@ #include "dax-private.h" #include "bus.h" -static int check_vma(struct dev_dax *dev_dax, struct vm_area_struct *vma, +static int check_vma_mmap(struct dev_dax *dev_dax, struct vm_area_struct *vma, const char *func) { struct dax_region *dax_region = dev_dax->region; @@ -41,17 +41,29 @@ static int check_vma(struct dev_dax *dev_dax, struct vm_area_struct *vma, return -EINVAL; } - if ((dax_region->pfn_flags & (PFN_DEV|PFN_MAP)) == PFN_DEV - && (vma->vm_flags & VM_DONTCOPY) == 0) { + if (!vma_is_dax(vma)) { dev_info_ratelimited(dev, - "%s: %s: fail, dax range requires MADV_DONTFORK\n", + "%s: %s: fail, vma is not DAX capable\n", current->comm, func); return -EINVAL; } - if (!vma_is_dax(vma)) { - dev_info_ratelimited(dev, - "%s: %s: fail, vma is not DAX capable\n", + return 0; +} + +static int check_vma(struct dev_dax *dev_dax, struct vm_area_struct *vma, + const char *func) +{ + int rc; + + rc = check_vma_mmap(dev_dax, vma, func); + if (rc < 0) + return rc; + + if ((dev_dax->region->pfn_flags & (PFN_DEV|PFN_MAP)) == PFN_DEV + && (vma->vm_flags & VM_DONTCOPY) == 0) { + dev_info_ratelimited(&dev_dax->dev, + "%s: %s: 
fail, dax range requires MADV_DONTFORK\n", current->comm, func); return -EINVAL; } @@ -315,7 +327,7 @@ static int dax_mmap(struct file *filp, struct vm_area_struct *vma) * fault time. */ id = dax_read_lock(); - rc = check_vma(dev_dax, vma, __func__); + rc = check_vma_mmap(dev_dax, vma, __func__); dax_read_unlock(id); if (rc) return rc; -- 2.17.1