RE: [PATCH] PCI: hv: Move completion variable from stack to heap in hv_compose_msi_msg()

From: Michael Kelley
Date: Wed May 26 2021 - 14:27:31 EST


From: longli@xxxxxxxxxxxxxxxxx <longli@xxxxxxxxxxxxxxxxx> Sent: Wednesday, May 12, 2021 1:07 AM
>
> hv_compose_msi_msg() may be called with interrupts disabled. It calls
> try_wait_for_completion() in a loop and may exit the loop early if the device is
> being ejected or it hits other errors. However, the VSP may send the
> completion packet after the loop exits, when the completion variable on the
> stack is no longer valid. This results in a kernel oops.
>
> Fix this by moving the completion variable from the stack to the heap, and use hbus
> to maintain a list of leftover completions for later cleanup if necessary.

Interesting problem. I haven't reviewed the details of your implementation
because I'd like to propose an alternate approach to solving the problem.

You have fixed the problem for hv_compose_msi_msg(), but it seems like the
same problem could occur in other places in pci-hyperv.c where a VMbus
request is sent, and waiting for the response could be aborted by the device
being rescinded.

The current code (and with your patch) passes the guest memory address of
the completion packet to Hyper-V as the requestID. Hyper-V responds and
passes back the requestID, whereupon hv_pci_onchannelcallback() treats it
as the guest memory address of the completion packet. This all assumes that
Hyper-V is trusted and that it doesn't pass back a bogus value that will be
treated as a guest memory address. But Andrea Parri has been updating
other VMbus drivers (like netvsc and storvsc) to *not* pass guest memory
addresses as the requestID. The pci-hyperv.c driver has not been fixed in this
regard, but I think this patch could take a big step in that direction.
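To make the trust assumption concrete, here is a minimal userspace model of the current scheme. It is a sketch under simplifying assumptions, not driver code: struct pkt stands in for struct pci_packet, and send_request() / on_channel_callback() are hypothetical stand-ins for vmbus_sendpacket() and hv_pci_onchannelcallback(), with the "host" reduced to a single global that echoes the ID back.

```c
#include <assert.h>
#include <stdint.h>

/* Simplified stand-in for the driver's struct pci_packet. */
struct pkt {
    void (*completion_func)(void *ctxt, int status);
    void *compl_ctxt;
};

/* Hypothetical model of the host: it just remembers the 64-bit ID. */
static uint64_t last_trans_id;

/* Sender: the guest address of the packet itself becomes the requestID. */
static void send_request(struct pkt *p)
{
    last_trans_id = (uint64_t)(uintptr_t)p;
}

/* Callback: the host-supplied ID is cast straight back to a pointer.
 * Nothing checks that it was ever issued or that the object is still live. */
static void on_channel_callback(uint64_t trans_id, int status)
{
    struct pkt *p = (struct pkt *)(uintptr_t)trans_id;

    p->completion_func(p->compl_ctxt, status);
}

/* Example completion function: store the status in the caller's int. */
static void record_status(void *ctxt, int status)
{
    *(int *)ctxt = status;
}
```

The round trip is correct only when the ID coming back is exactly one the guest issued and the object behind it is still alive; a late or forged value is dereferenced without any check.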

My alternate approach is as follows:
1. For each PCI VMbus channel, keep a 64-bit counter. When a VMbus message
is to be sent, increment the counter atomically, and send the next value as the
requestID. The counter will not wrap-around in any practical time period, so
the requestIDs are essentially unique. Or just read a clock value to get a unique
requestID.
2. Also keep a per-channel list of mappings from requestID to the guest memory
address of the completion packet. For PCI channels, there will be very few
requests outstanding concurrently, so this can be a simple linked list, protected
by a spin lock.
3. Before sending a new VMbus message that is expecting a response, add the
mapping to the list. The guest memory address can be for a stack local, like
the current code.
4. When the sending function completes, either because the response was
received, or because wait_for_response() aborted, remove the mapping from
the linked list.
5. hv_pci_onchannelcallback() gets the requestID from Hyper-V and looks it
up in the linked list. If there's no match in the linked list, the completion
response from Hyper-V is ignored. It's either a late response or a completely
bogus response from Hyper-V. If there is a match, then the address of the
completion packet is available and valid. The completion function will need
to run while the spin lock is held on the linked list, so that the completion
packet address is ensured to remain valid while the completion function
executes.
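The five steps above can be modeled in userspace C roughly as follows. This is an illustrative sketch under stated assumptions, not proposed pci-hyperv.c code: all names (struct pending_req, req_register(), and so on) are hypothetical, a C11 atomic_flag spin loop stands in for a kernel spinlock, and the host is simulated by calling the callback directly. As a sanity check on the wrap-around claim in step 1: even at one billion requests per second, a 64-bit counter needs 2^64 / 10^9, about 1.8 x 10^10 seconds or roughly 584 years, to wrap.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical per-request mapping entry (step 2). */
struct pending_req {
    struct pending_req *next;
    uint64_t req_id;
    void (*completion_func)(void *ctxt, int status);
    void *compl_ctxt;              /* may point at a caller's stack local */
};

/* Hypothetical per-channel state. */
struct chan_state {
    atomic_uint_fast64_t next_id;  /* step 1: per-channel counter */
    atomic_flag lock;              /* stand-in for a spinlock */
    struct pending_req *pending;   /* step 2: requestID -> packet list */
};

static void chan_init(struct chan_state *c)
{
    atomic_init(&c->next_id, 0);
    atomic_flag_clear(&c->lock);
    c->pending = NULL;
}

static void chan_lock(struct chan_state *c)
{
    while (atomic_flag_test_and_set_explicit(&c->lock, memory_order_acquire))
        ; /* spin */
}

static void chan_unlock(struct chan_state *c)
{
    atomic_flag_clear_explicit(&c->lock, memory_order_release);
}

/* Steps 1 and 3: take a fresh nonzero ID and publish the mapping before the
 * message is sent. The returned ID is what would be passed to the host. */
static uint64_t req_register(struct chan_state *c, struct pending_req *r)
{
    r->req_id = atomic_fetch_add(&c->next_id, 1) + 1;
    chan_lock(c);
    r->next = c->pending;
    c->pending = r;
    chan_unlock(c);
    return r->req_id;
}

/* Step 4: called whether the response arrived or the wait was aborted.
 * Once this returns, the callback can no longer reach r. */
static void req_unregister(struct chan_state *c, struct pending_req *r)
{
    chan_lock(c);
    for (struct pending_req **pp = &c->pending; *pp; pp = &(*pp)->next) {
        if (*pp == r) {
            *pp = r->next;
            break;
        }
    }
    chan_unlock(c);
}

/* Step 5: unknown IDs (late or bogus) are ignored; a match runs its
 * completion with the lock held so the entry cannot be freed underneath. */
static bool on_channel_callback(struct chan_state *c, uint64_t req_id, int status)
{
    bool found = false;

    chan_lock(c);
    for (struct pending_req *r = c->pending; r; r = r->next) {
        if (r->req_id == req_id) {
            r->completion_func(r->compl_ctxt, status);
            found = true;
            break;
        }
    }
    chan_unlock(c);
    return found;
}

static void record_status(void *ctxt, int status)
{
    *(int *)ctxt = status;
}
```

A request that is unregistered before the host answers simply drops out of the list, so a late completion finds no match and never touches the possibly reused stack memory. The linear search is only reasonable because a PCI channel has very few requests outstanding at once.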

I don't think my proposed approach is any more complicated than what your
patch does, and it is a step in the direction of fully hardening the
pci-hyperv.c driver.

This approach is a bit different from netvsc and storvsc because those drivers
must handle lots of in-flight requests, and searching a linked list in the
onchannelcallback function would be too slow. The overall idea is the same,
but a different approach is used to generate requestIDs and to map
between requestIDs and guest memory addresses.
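For comparison, one common way to get O(1) mapping when many requests are in flight is a preallocated slot array threaded into a free list, with the requestID derived from the slot index. The sketch below is a simplified illustration of that general idea only; it is not the VMbus code used by netvsc or storvsc, every name in it is hypothetical, NSLOTS is an arbitrary assumed capacity, and the locking a real version needs around both functions is omitted for brevity.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#define NSLOTS 4   /* hypothetical capacity */

/* Hypothetical slot table: IDs are slot index + 1, so 0 is never valid. */
struct requestor {
    void *addr[NSLOTS];          /* guest-side context for each busy slot */
    uint32_t next_free[NSLOTS];  /* free-list links for idle slots */
    bool in_use[NSLOTS];
    uint32_t free_head;          /* NSLOTS means the table is full */
};

static void requestor_init(struct requestor *rq)
{
    for (uint32_t i = 0; i < NSLOTS; i++) {
        rq->addr[i] = NULL;
        rq->next_free[i] = i + 1;
        rq->in_use[i] = false;
    }
    rq->free_head = 0;
}

/* Returns a nonzero request ID, or 0 if every slot is busy. */
static uint64_t requestor_get_id(struct requestor *rq, void *addr)
{
    uint32_t slot = rq->free_head;

    if (slot == NSLOTS)
        return 0;
    rq->free_head = rq->next_free[slot];
    rq->addr[slot] = addr;
    rq->in_use[slot] = true;
    return (uint64_t)slot + 1;
}

/* Validates a host-supplied ID, releases its slot and returns the context.
 * Out-of-range or idle IDs yield NULL instead of a wild pointer. */
static void *requestor_take(struct requestor *rq, uint64_t id)
{
    if (id == 0 || id > NSLOTS)
        return NULL;

    uint32_t slot = (uint32_t)(id - 1);
    if (!rq->in_use[slot])
        return NULL;

    void *addr = rq->addr[slot];
    rq->in_use[slot] = false;
    rq->addr[slot] = NULL;
    rq->next_free[slot] = rq->free_head;
    rq->free_head = slot;
    return addr;
}
```

Here a slot is released only when its completion is consumed, so an ID is not reissued while the host may still answer it; lookup is a bounds check plus an array index rather than a list walk.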

Thoughts?

Michael

>
> Signed-off-by: Long Li <longli@xxxxxxxxxxxxx>
> ---
> drivers/pci/controller/pci-hyperv.c | 97 +++++++++++++++++++----------
> 1 file changed, 65 insertions(+), 32 deletions(-)
>
> diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-hyperv.c
> index 9499ae3275fe..29fe26e2193c 100644
> --- a/drivers/pci/controller/pci-hyperv.c
> +++ b/drivers/pci/controller/pci-hyperv.c
> @@ -473,6 +473,9 @@ struct hv_pcibus_device {
> struct msi_controller msi_chip;
> struct irq_domain *irq_domain;
>
> + struct list_head compose_msi_msg_ctxt_list;
> + spinlock_t compose_msi_msg_ctxt_list_lock;
> +
> spinlock_t retarget_msi_interrupt_lock;
>
> struct workqueue_struct *wq;
> @@ -552,6 +555,17 @@ struct hv_pci_compl {
> s32 completion_status;
> };
>
> +struct compose_comp_ctxt {
> + struct hv_pci_compl comp_pkt;
> + struct tran_int_desc int_desc;
> +};
> +
> +struct compose_msi_msg_ctxt {
> + struct list_head list;
> + struct pci_packet pci_pkt;
> + struct compose_comp_ctxt comp;
> +};
> +
> static void hv_pci_onchannelcallback(void *context);
>
> /**
> @@ -1293,11 +1307,6 @@ static void hv_irq_unmask(struct irq_data *data)
> pci_msi_unmask_irq(data);
> }
>
> -struct compose_comp_ctxt {
> - struct hv_pci_compl comp_pkt;
> - struct tran_int_desc int_desc;
> -};
> -
> static void hv_pci_compose_compl(void *context, struct pci_response *resp,
> int resp_packet_size)
> {
> @@ -1373,16 +1382,12 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
> struct pci_bus *pbus;
> struct pci_dev *pdev;
> struct cpumask *dest;
> - struct compose_comp_ctxt comp;
> struct tran_int_desc *int_desc;
> - struct {
> - struct pci_packet pci_pkt;
> - union {
> - struct pci_create_interrupt v1;
> - struct pci_create_interrupt2 v2;
> - } int_pkts;
> - } __packed ctxt;
> -
> + struct compose_msi_msg_ctxt *ctxt;
> + union {
> + struct pci_create_interrupt v1;
> + struct pci_create_interrupt2 v2;
> + } int_pkts;
> u32 size;
> int ret;
>
> @@ -1402,18 +1407,24 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
> hv_int_desc_free(hpdev, int_desc);
> }
>
> + ctxt = kzalloc(sizeof(*ctxt), GFP_ATOMIC);
> + if (!ctxt)
> + goto drop_reference;
> +
> int_desc = kzalloc(sizeof(*int_desc), GFP_ATOMIC);
> - if (!int_desc)
> + if (!int_desc) {
> + kfree(ctxt);
> goto drop_reference;
> + }
>
> - memset(&ctxt, 0, sizeof(ctxt));
> - init_completion(&comp.comp_pkt.host_event);
> - ctxt.pci_pkt.completion_func = hv_pci_compose_compl;
> - ctxt.pci_pkt.compl_ctxt = &comp;
> + memset(ctxt, 0, sizeof(*ctxt));
> + init_completion(&ctxt->comp.comp_pkt.host_event);
> + ctxt->pci_pkt.completion_func = hv_pci_compose_compl;
> + ctxt->pci_pkt.compl_ctxt = &ctxt->comp;
>
> switch (hbus->protocol_version) {
> case PCI_PROTOCOL_VERSION_1_1:
> - size = hv_compose_msi_req_v1(&ctxt.int_pkts.v1,
> + size = hv_compose_msi_req_v1(&int_pkts.v1,
> dest,
> hpdev->desc.win_slot.slot,
> cfg->vector);
> @@ -1421,7 +1432,7 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
>
> case PCI_PROTOCOL_VERSION_1_2:
> case PCI_PROTOCOL_VERSION_1_3:
> - size = hv_compose_msi_req_v2(&ctxt.int_pkts.v2,
> + size = hv_compose_msi_req_v2(&int_pkts.v2,
> dest,
> hpdev->desc.win_slot.slot,
> cfg->vector);
> @@ -1434,17 +1445,18 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
> */
> dev_err(&hbus->hdev->device,
> "Unexpected vPCI protocol, update driver.");
> + kfree(ctxt);
> goto free_int_desc;
> }
>
> - ret = vmbus_sendpacket(hpdev->hbus->hdev->channel, &ctxt.int_pkts,
> - size, (unsigned long)&ctxt.pci_pkt,
> + ret = vmbus_sendpacket(hpdev->hbus->hdev->channel, &int_pkts,
> + size, (unsigned long)&ctxt->pci_pkt,
> VM_PKT_DATA_INBAND,
> VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
> if (ret) {
> dev_err(&hbus->hdev->device,
> - "Sending request for interrupt failed: 0x%x",
> - comp.comp_pkt.completion_status);
> + "Sending request for interrupt failed: 0x%x", ret);
> + kfree(ctxt);
> goto free_int_desc;
> }
>
> @@ -1458,7 +1470,7 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
> * Since this function is called with IRQ locks held, can't
> * do normal wait for completion; instead poll.
> */
> - while (!try_wait_for_completion(&comp.comp_pkt.host_event)) {
> + while (!try_wait_for_completion(&ctxt->comp.comp_pkt.host_event)) {
> unsigned long flags;
>
> /* 0xFFFF means an invalid PCI VENDOR ID. */
> @@ -1494,10 +1506,11 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
>
> tasklet_enable(&channel->callback_event);
>
> - if (comp.comp_pkt.completion_status < 0) {
> + if (ctxt->comp.comp_pkt.completion_status < 0) {
> dev_err(&hbus->hdev->device,
> "Request for interrupt failed: 0x%x",
> - comp.comp_pkt.completion_status);
> + ctxt->comp.comp_pkt.completion_status);
> + kfree(ctxt);
> goto free_int_desc;
> }
>
> @@ -1506,23 +1519,36 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
> * irq_set_chip_data() here would be appropriate, but the lock it takes
> * is already held.
> */
> - *int_desc = comp.int_desc;
> + *int_desc = ctxt->comp.int_desc;
> data->chip_data = int_desc;
>
> /* Pass up the result. */
> - msg->address_hi = comp.int_desc.address >> 32;
> - msg->address_lo = comp.int_desc.address & 0xffffffff;
> - msg->data = comp.int_desc.data;
> + msg->address_hi = ctxt->comp.int_desc.address >> 32;
> + msg->address_lo = ctxt->comp.int_desc.address & 0xffffffff;
> + msg->data = ctxt->comp.int_desc.data;
>
> put_pcichild(hpdev);
> + kfree(ctxt);
> return;
>
> enable_tasklet:
> tasklet_enable(&channel->callback_event);
> +
> + /*
> + * Move uncompleted context to the leftover list.
> + * The host may send completion at a later time, and we ignore this
> + * completion but keep the memory reference valid.
> + */
> + spin_lock(&hbus->compose_msi_msg_ctxt_list_lock);
> + list_add_tail(&ctxt->list, &hbus->compose_msi_msg_ctxt_list);
> + spin_unlock(&hbus->compose_msi_msg_ctxt_list_lock);
> +
> free_int_desc:
> kfree(int_desc);
> +
> drop_reference:
> put_pcichild(hpdev);
> +
> return_null_message:
> msg->address_hi = 0;
> msg->address_lo = 0;
> @@ -3076,9 +3102,11 @@ static int hv_pci_probe(struct hv_device *hdev,
> INIT_LIST_HEAD(&hbus->children);
> INIT_LIST_HEAD(&hbus->dr_list);
> INIT_LIST_HEAD(&hbus->resources_for_children);
> + INIT_LIST_HEAD(&hbus->compose_msi_msg_ctxt_list);
> spin_lock_init(&hbus->config_lock);
> spin_lock_init(&hbus->device_list_lock);
> spin_lock_init(&hbus->retarget_msi_interrupt_lock);
> + spin_lock_init(&hbus->compose_msi_msg_ctxt_list_lock);
> hbus->wq = alloc_ordered_workqueue("hv_pci_%x", 0,
> hbus->sysdata.domain);
> if (!hbus->wq) {
> @@ -3282,6 +3310,7 @@ static int hv_pci_bus_exit(struct hv_device *hdev, bool keep_devs)
> static int hv_pci_remove(struct hv_device *hdev)
> {
> struct hv_pcibus_device *hbus;
> + struct compose_msi_msg_ctxt *ctxt, *tmp;
> int ret;
>
> hbus = hv_get_drvdata(hdev);
> @@ -3318,6 +3347,10 @@ static int hv_pci_remove(struct hv_device *hdev)
>
> hv_put_dom_num(hbus->sysdata.domain);
>
> + list_for_each_entry_safe(ctxt, tmp, &hbus->compose_msi_msg_ctxt_list, list) {
> + list_del(&ctxt->list);
> + kfree(ctxt);
> + }
> kfree(hbus);
> return ret;
> }
> --
> 2.27.0