Signed-off-by: Vitor Massaru Iha <vitor@xxxxxxxxxxx>
---
 include/kunit/test.h  |  7 +++++++
 lib/kunit/try-catch.c | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/include/kunit/test.h b/include/kunit/test.h
--- a/include/kunit/test.h
+++ b/include/kunit/test.h
@@ -222,6 +222,13 @@ struct kunit {
 	 * protect it with some type of lock.
 	 */
 	struct list_head resources; /* Protected by lock. */
+	/*
+	 * Address space of the task that started the test case, or NULL if
+	 * that was a kernel thread. The test case's thread attaches it with
+	 * kthread_use_mm() and owns the reference kunit_try_catch_run() took,
+	 * so it must not be used once the test case has finished.
+	 */
+	struct mm_struct *mm;
 };
 
 void kunit_init_test(struct kunit *test, const char *name, char *log);
diff --git a/lib/kunit/try-catch.c b/lib/kunit/try-catch.c
--- a/lib/kunit/try-catch.c
+++ b/lib/kunit/try-catch.c
@@ -11,6 +11,7 @@
 #include <linux/completion.h>
 #include <linux/kernel.h>
 #include <linux/kthread.h>
+#include <linux/sched/mm.h>
 
 #include "try-catch-impl.h"
 
@@ -24,8 +25,27 @@ EXPORT_SYMBOL_GPL(kunit_try_catch_throw);
 static int kunit_generic_run_threadfn_adapter(void *data)
 {
 	struct kunit_try_catch *try_catch = data;
+	struct mm_struct *mm = try_catch->test->mm;
+
+	/*
+	 * Run the test case in the address space of the task that started
+	 * it, if that task has one, so the test case can access user memory.
+	 */
+	if (mm)
+		kthread_use_mm(mm);
 
 	try_catch->try(try_catch->context);
+
+	/*
+	 * Drop the reference taken by kunit_try_catch_run(). If the test
+	 * case aborted via kunit_try_catch_throw(), we never get here: the
+	 * thread exits with the mm still attached and exit_mm() drops the
+	 * reference instead.
+	 */
+	if (mm) {
+		kthread_unuse_mm(mm);
+		mmput(mm);
+	}
 
 	complete_and_exit(try_catch->try_completion, 0);
 }
@@ -65,10 +85,23 @@ void kunit_try_catch_run(struct kunit_try_catch *try_catch, void *context)
 	try_catch->context = context;
 	try_catch->try_completion = &try_completion;
 	try_catch->try_result = 0;
+
+	/*
+	 * Hand the caller's address space to the test case's thread, which
+	 * takes over this reference. get_task_mm() returns NULL when the
+	 * caller is a kernel thread.
+	 */
+	test->mm = get_task_mm(current);
+
 	task_struct = kthread_run(kunit_generic_run_threadfn_adapter,
 				  try_catch,
 				  "kunit_try_catch_thread");
 	if (IS_ERR(task_struct)) {
+		/* The thread never ran, so the reference is still ours. */
+		if (test->mm) {
+			mmput(test->mm);
+			test->mm = NULL;
+		}
 		try_catch->catch(try_catch->context);
 		return;
 	}
-- 
2.26.2