Re: [PATCH 1/2] HID: bpf: remove double fdget()
From: Benjamin Tissoires
Date: Tue Jan 23 2024 - 12:20:13 EST
On Tue, Jan 23, 2024 at 5:41 PM Benjamin Tissoires <bentiss@xxxxxxxxxx> wrote:
>
> When the kfunc hid_bpf_attach_prog() is called, we call fdget() twice:
> once to fetch the type of the bpf program, and once to actually
> attach the program to the device.
>
> The problem is that between those two calls, we have no guarantee that
> prog_fd still refers to the same program.
>
> Solve this by calling bpf_prog_get() only once, earlier, and using the
> returned program to fetch the program type.
>
> Reported-by: Dan Carpenter <dan.carpenter@xxxxxxxxxx>
> Link: https://lore.kernel.org/bpf/CAO-hwJJ8vh8JD3-P43L-_CLNmPx0hWj44aom0O838vfP4=_1CA@xxxxxxxxxxxxxx/T/#t
> Cc: stable@xxxxxxxxxxxxxxx
Sigh, I forgot the Fixes tag:
Fixes: f5c27da4e3c8 ("HID: initial BPF implementation")
Cheers,
Benjamin
> Signed-off-by: Benjamin Tissoires <bentiss@xxxxxxxxxx>
> ---
> drivers/hid/bpf/hid_bpf_dispatch.c | 66 ++++++++++++++++++++++++-------------
> drivers/hid/bpf/hid_bpf_dispatch.h | 4 +--
> drivers/hid/bpf/hid_bpf_jmp_table.c | 20 ++---------
> 3 files changed, 49 insertions(+), 41 deletions(-)
>
> diff --git a/drivers/hid/bpf/hid_bpf_dispatch.c b/drivers/hid/bpf/hid_bpf_dispatch.c
> index d9ef45fcaeab..5111d1fef0d3 100644
> --- a/drivers/hid/bpf/hid_bpf_dispatch.c
> +++ b/drivers/hid/bpf/hid_bpf_dispatch.c
> @@ -241,6 +241,39 @@ int hid_bpf_reconnect(struct hid_device *hdev)
> return 0;
> }
>
> +static int do_hid_bpf_attach_prog(struct hid_device *hdev, int prog_fd, struct bpf_prog *prog,
> + __u32 flags)
> +{
> + int fd, err, prog_type;
> +
> + prog_type = hid_bpf_get_prog_attach_type(prog);
> + if (prog_type < 0)
> + return prog_type;
> +
> + if (prog_type >= HID_BPF_PROG_TYPE_MAX)
> + return -EINVAL;
> +
> + if (prog_type == HID_BPF_PROG_TYPE_DEVICE_EVENT) {
> + err = hid_bpf_allocate_event_data(hdev);
> + if (err)
> + return err;
> + }
> +
> + fd = __hid_bpf_attach_prog(hdev, prog_type, prog_fd, prog, flags);
> + if (fd < 0)
> + return fd;
> +
> + if (prog_type == HID_BPF_PROG_TYPE_RDESC_FIXUP) {
> + err = hid_bpf_reconnect(hdev);
> + if (err) {
> + close_fd(fd);
> + return err;
> + }
> + }
> +
> + return fd;
> +}
> +
> /**
> * hid_bpf_attach_prog - Attach the given @prog_fd to the given HID device
> *
> @@ -257,18 +290,13 @@ noinline int
> hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
> {
> struct hid_device *hdev;
> + struct bpf_prog *prog;
> struct device *dev;
> - int fd, err, prog_type = hid_bpf_get_prog_attach_type(prog_fd);
> + int fd;
>
> if (!hid_bpf_ops)
> return -EINVAL;
>
> - if (prog_type < 0)
> - return prog_type;
> -
> - if (prog_type >= HID_BPF_PROG_TYPE_MAX)
> - return -EINVAL;
> -
> if ((flags & ~HID_BPF_FLAG_MASK))
> return -EINVAL;
>
> @@ -278,23 +306,17 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
>
> hdev = to_hid_device(dev);
>
> - if (prog_type == HID_BPF_PROG_TYPE_DEVICE_EVENT) {
> - err = hid_bpf_allocate_event_data(hdev);
> - if (err)
> - return err;
> - }
> + /*
> + * take a ref on the prog itself, it will be released
> + * on errors or when it'll be detached
> + */
> + prog = bpf_prog_get(prog_fd);
> + if (IS_ERR(prog))
> + return PTR_ERR(prog);
>
> - fd = __hid_bpf_attach_prog(hdev, prog_type, prog_fd, flags);
> + fd = do_hid_bpf_attach_prog(hdev, prog_fd, prog, flags);
> if (fd < 0)
> - return fd;
> -
> - if (prog_type == HID_BPF_PROG_TYPE_RDESC_FIXUP) {
> - err = hid_bpf_reconnect(hdev);
> - if (err) {
> - close_fd(fd);
> - return err;
> - }
> - }
> + bpf_prog_put(prog);
>
> return fd;
> }
> diff --git a/drivers/hid/bpf/hid_bpf_dispatch.h b/drivers/hid/bpf/hid_bpf_dispatch.h
> index 63dfc8605cd2..fbe0639d09f2 100644
> --- a/drivers/hid/bpf/hid_bpf_dispatch.h
> +++ b/drivers/hid/bpf/hid_bpf_dispatch.h
> @@ -12,9 +12,9 @@ struct hid_bpf_ctx_kern {
>
> int hid_bpf_preload_skel(void);
> void hid_bpf_free_links_and_skel(void);
> -int hid_bpf_get_prog_attach_type(int prog_fd);
> +int hid_bpf_get_prog_attach_type(struct bpf_prog *prog);
> int __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type, int prog_fd,
> - __u32 flags);
> + struct bpf_prog *prog, __u32 flags);
> void __hid_bpf_destroy_device(struct hid_device *hdev);
> int hid_bpf_prog_run(struct hid_device *hdev, enum hid_bpf_prog_type type,
> struct hid_bpf_ctx_kern *ctx_kern);
> diff --git a/drivers/hid/bpf/hid_bpf_jmp_table.c b/drivers/hid/bpf/hid_bpf_jmp_table.c
> index eca34b7372f9..12f7cebddd73 100644
> --- a/drivers/hid/bpf/hid_bpf_jmp_table.c
> +++ b/drivers/hid/bpf/hid_bpf_jmp_table.c
> @@ -333,15 +333,10 @@ static int hid_bpf_insert_prog(int prog_fd, struct bpf_prog *prog)
> return err;
> }
>
> -int hid_bpf_get_prog_attach_type(int prog_fd)
> +int hid_bpf_get_prog_attach_type(struct bpf_prog *prog)
> {
> - struct bpf_prog *prog = NULL;
> - int i;
> int prog_type = HID_BPF_PROG_TYPE_UNDEF;
> -
> - prog = bpf_prog_get(prog_fd);
> - if (IS_ERR(prog))
> - return PTR_ERR(prog);
> + int i;
>
> for (i = 0; i < HID_BPF_PROG_TYPE_MAX; i++) {
> if (hid_bpf_btf_ids[i] == prog->aux->attach_btf_id) {
> @@ -350,8 +345,6 @@ int hid_bpf_get_prog_attach_type(int prog_fd)
> }
> }
>
> - bpf_prog_put(prog);
> -
> return prog_type;
> }
>
> @@ -388,19 +381,13 @@ static const struct bpf_link_ops hid_bpf_link_lops = {
> /* called from syscall */
> noinline int
> __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
> - int prog_fd, __u32 flags)
> + int prog_fd, struct bpf_prog *prog, __u32 flags)
> {
> struct bpf_link_primer link_primer;
> struct hid_bpf_link *link;
> - struct bpf_prog *prog = NULL;
> struct hid_bpf_prog_entry *prog_entry;
> int cnt, err = -EINVAL, prog_table_idx = -1;
>
> - /* take a ref on the prog itself */
> - prog = bpf_prog_get(prog_fd);
> - if (IS_ERR(prog))
> - return PTR_ERR(prog);
> -
> mutex_lock(&hid_bpf_attach_lock);
>
> link = kzalloc(sizeof(*link), GFP_USER);
> @@ -467,7 +454,6 @@ __hid_bpf_attach_prog(struct hid_device *hdev, enum hid_bpf_prog_type prog_type,
> err_unlock:
> mutex_unlock(&hid_bpf_attach_lock);
>
> - bpf_prog_put(prog);
> kfree(link);
>
> return err;
>
> --
> 2.43.0
>