Signed-off-by: Benjamin Tissoires <bentiss@xxxxxxxxxx> --- drivers/hid/bpf/hid_bpf_dispatch.c | 111 ++++++++++++++++++++++-------------- drivers/hid/bpf/hid_bpf_dispatch.h | 4 +- drivers/hid/bpf/hid_bpf_jmp_table.c | 30 +++++++--- 3 files changed, 92 insertions(+), 53 deletions(-) diff --git a/drivers/hid/bpf/hid_bpf_dispatch.c b/drivers/hid/bpf/hid_bpf_dispatch.c index 81073db6c617..7ede657f459b 100644 --- a/drivers/hid/bpf/hid_bpf_dispatch.c +++ b/drivers/hid/bpf/hid_bpf_dispatch.c @@ -70,7 +70,7 @@ dispatch_hid_bpf_device_event(struct hid_device *hdev, enum hid_report_type type memset(ctx_kern.data, 0, hdev->bpf.allocated_data); memcpy(ctx_kern.data, data, *size); - ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_DEVICE_EVENT, &ctx_kern); + ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_DEVICE_EVENT, &ctx_kern, false); if (ret < 0) return ERR_PTR(ret); @@ -122,7 +122,7 @@ u8 *call_hid_bpf_rdesc_fixup(struct hid_device *hdev, u8 *rdesc, unsigned int *s memcpy(ctx_kern.data, rdesc, min_t(unsigned int, *size, HID_MAX_DESCRIPTOR_SIZE)); - ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_RDESC_FIXUP, &ctx_kern); + ret = hid_bpf_prog_run(hdev, HID_BPF_PROG_TYPE_RDESC_FIXUP, &ctx_kern, false); if (ret < 0) goto ignore_bpf; @@ -205,7 +205,7 @@ int hid_bpf_reconnect(struct hid_device *hdev) static int do_hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, hid_bpf_cb_t prog_fn, struct bpf_prog *prog, - __u32 flags) + __u32 flags, bool sleepable) { int fd, err; @@ -213,7 +213,7 @@ static int do_hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_typ if (err) return err; - fd = __hid_bpf_attach_prog(hdev, prog_type, prog_fn, prog, flags); + fd = __hid_bpf_attach_prog(hdev, prog_type, prog_fn, prog, flags, sleepable); if (fd < 0) return fd; @@ -228,6 +228,56 @@ static int do_hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_typ return fd; } +static int +hid_bpf_attach_prog(unsigned int hid_id, enum hid_bpf_prog_type prog_type, 
+ hid_bpf_cb_t prog_fn, __u32 flags, void *prog__aux, + bool sleepable) +{ + struct bpf_prog_aux *aux = (struct bpf_prog_aux *)prog__aux; + struct bpf_prog *prog = aux->prog; + struct hid_device *hdev; + struct device *dev; + int err, fd; + + if (!hid_bpf_ops) + return -EINVAL; + + if ((flags & ~HID_BPF_FLAG_MASK)) + return -EINVAL; + + if (prog_type < 0 || prog_type >= HID_BPF_PROG_TYPE_MAX) + return -EINVAL; + + dev = bus_find_device(hid_bpf_ops->bus_type, NULL, &hid_id, device_match_id); + if (!dev) + return -EINVAL; + + hdev = to_hid_device(dev); + + /* + * take a ref on the prog itself, it will be released + * on errors or when it'll be detached + */ + prog = bpf_prog_inc_not_zero(prog); + if (IS_ERR(prog)) { + err = PTR_ERR(prog); + goto out_dev_put; + } + + fd = do_hid_bpf_attach_prog(hdev, prog_type, prog_fn, prog, flags, sleepable); + if (fd < 0) { + err = fd; + goto out_prog_put; + } + + return fd; + + out_prog_put: + bpf_prog_put(prog); + out_dev_put: + put_device(dev); + return err; +} /* Disables missing prototype warnings */ __bpf_kfunc_start_defs(); @@ -272,50 +322,22 @@ __bpf_kfunc int hid_bpf_attach_prog_impl(unsigned int hid_id, enum hid_bpf_prog_type prog_type, hid_bpf_cb_t prog_fn__async, __u32 flags, void *prog__aux) { - struct bpf_prog_aux *aux = (struct bpf_prog_aux *)prog__aux; - struct bpf_prog *prog = aux->prog; - struct hid_device *hdev; - struct device *dev; - int err, fd; - - if (!hid_bpf_ops) - return -EINVAL; - - if ((flags & ~HID_BPF_FLAG_MASK)) - return -EINVAL; - - if (prog_type < 0 || prog_type >= HID_BPF_PROG_TYPE_MAX) - return -EINVAL; + return hid_bpf_attach_prog(hid_id, prog_type, prog_fn__async, flags, prog__aux, false); +} - dev = bus_find_device(hid_bpf_ops->bus_type, NULL, &hid_id, device_match_id); - if (!dev) +__bpf_kfunc int +hid_bpf_attach_sleepable_prog_impl(unsigned int hid_id, enum hid_bpf_prog_type prog_type, + hid_bpf_cb_t prog_fn__s_async, __u32 flags, void *prog__aux) +{ + switch (prog_type) { + case 
HID_BPF_PROG_TYPE_RAW_REQUEST: + /* only RAW_REQUEST programs may be attached as sleepable */ + break; + default: return -EINVAL; - - hdev = to_hid_device(dev); - - /* - * take a ref on the prog itself, it will be released - * on errors or when it'll be detached - */ - prog = bpf_prog_inc_not_zero(prog); - if (IS_ERR(prog)) { - err = PTR_ERR(prog); - goto out_dev_put; - } - - fd = do_hid_bpf_attach_prog(hdev, prog_type, prog_fn__async, prog, flags); - if (fd < 0) { - err = fd; - goto out_prog_put; - } + } - return fd; - - out_prog_put: - bpf_prog_put(prog); - out_dev_put: - put_device(dev); - return err; + return hid_bpf_attach_prog(hid_id, prog_type, prog_fn__s_async, flags, prog__aux, true); } /** @@ -538,6 +560,7 @@ __bpf_kfunc_end_defs(); BTF_KFUNCS_START(hid_bpf_kfunc_ids) BTF_ID_FLAGS(func, hid_bpf_get_data, KF_RET_NULL) BTF_ID_FLAGS(func, hid_bpf_attach_prog_impl, KF_SLEEPABLE) +BTF_ID_FLAGS(func, hid_bpf_attach_sleepable_prog_impl, KF_SLEEPABLE) BTF_ID_FLAGS(func, hid_bpf_allocate_context, KF_ACQUIRE | KF_RET_NULL | KF_SLEEPABLE) BTF_ID_FLAGS(func, hid_bpf_release_context, KF_RELEASE | KF_SLEEPABLE) BTF_ID_FLAGS(func, hid_bpf_hw_request, KF_SLEEPABLE) diff --git a/drivers/hid/bpf/hid_bpf_dispatch.h b/drivers/hid/bpf/hid_bpf_dispatch.h index 60455a2af216..f9833603e49f 100644 --- a/drivers/hid/bpf/hid_bpf_dispatch.h +++ b/drivers/hid/bpf/hid_bpf_dispatch.h @@ -14,10 +14,10 @@ typedef int (*hid_bpf_cb_t)(struct hid_bpf_ctx *hid_ctx); int __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, int (prog_fn__async)(struct hid_bpf_ctx *hid_ctx), - struct bpf_prog *prog, __u32 flags); + struct bpf_prog *prog, __u32 flags, bool sleepable); void __hid_bpf_destroy_device(struct hid_device *hdev); int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type, - struct hid_bpf_ctx_kern *ctx_kern); + struct hid_bpf_ctx_kern *ctx_kern, bool is_sleepable); int hid_bpf_reconnect(struct hid_device *hdev); struct bpf_prog; diff --git 
a/drivers/hid/bpf/hid_bpf_jmp_table.c
b/drivers/hid/bpf/hid_bpf_jmp_table.c index 8d53d41b599b..4cceff354962 100644 --- a/drivers/hid/bpf/hid_bpf_jmp_table.c +++ b/drivers/hid/bpf/hid_bpf_jmp_table.c @@ -39,6 +39,7 @@ struct hid_bpf_prog_entry { struct hid_bpf_prog_cb { struct bpf_prog *prog; void *fn; + bool sleepable; }; struct hid_bpf_jmp_table { @@ -99,14 +100,20 @@ static int hid_bpf_program_count(struct hid_device *hdev, } int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type, - struct hid_bpf_ctx_kern *ctx_kern) + struct hid_bpf_ctx_kern *ctx_kern, bool is_sleepable) { struct hid_bpf_prog_list *prog_list; bpf_callback_t prog_fn; int i, idx, err = 0; - rcu_read_lock(); - prog_list = rcu_dereference(hdev->bpf.progs[type]); + if (is_sleepable) { + might_fault(); + rcu_read_lock_trace(); + } else { + rcu_read_lock(); + } + /* progs[] must be read inside the read-side section entered above */ + prog_list = rcu_dereference_check(hdev->bpf.progs[type], rcu_read_lock_trace_held()); if (!prog_list) goto out_unlock; @@ -117,6 +124,10 @@ int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type, if (!test_bit(idx, jmp_table.enabled)) continue; + /* never run a sleepable program from a non-sleepable context */ + if (!is_sleepable && jmp_table.prog_cbs[idx].sleepable) + continue; + ctx_kern->ctx.index = idx; prog_fn = jmp_table.prog_cbs[idx].fn; migrate_disable(); @@ -129,7 +140,10 @@ int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type, } out_unlock: - rcu_read_unlock(); + if (is_sleepable) + rcu_read_unlock_trace(); + else + rcu_read_unlock(); return err; } @@ -279,7 +293,7 @@ static void hid_bpf_release_progs(struct work_struct *work) * Insert the given BPF program represented by its function call in the jmp table. * Returns the index in the jump table or a negative error.
*/ -static int hid_bpf_insert_prog(struct bpf_prog *prog, hid_bpf_cb_t prog_fn) +static int hid_bpf_insert_prog(struct bpf_prog *prog, hid_bpf_cb_t prog_fn, bool sleepable) { int i, index = -1, err = -EINVAL; @@ -289,6 +303,7 @@ static int hid_bpf_insert_prog(struct bpf_prog *prog, hid_bpf_cb_t prog_fn) /* mark the index as used */ jmp_table.prog_cbs[i].fn = prog_fn; jmp_table.prog_cbs[i].prog = prog; + jmp_table.prog_cbs[i].sleepable = sleepable; index = i; __set_bit(i, jmp_table.enabled); } @@ -340,7 +355,8 @@ static const struct bpf_link_ops hid_bpf_link_lops = { /* called from syscall */ noinline int __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, - hid_bpf_cb_t prog_fn, struct bpf_prog *prog, __u32 flags) + hid_bpf_cb_t prog_fn, struct bpf_prog *prog, __u32 flags, + bool sleepable) { struct bpf_link_primer link_primer; struct hid_bpf_link *link; @@ -370,7 +386,7 @@ __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, goto err_unlock; } - prog_table_idx = hid_bpf_insert_prog(prog, prog_fn); + prog_table_idx = hid_bpf_insert_prog(prog, prog_fn, sleepable); /* if the jmp table is full, abort */ if (prog_table_idx < 0) { err = prog_table_idx; -- 2.44.0