[PATCH v2 2/2] Drivers: hv: vmbus: Add a channel ring buffer mutex lock
From: Kimberly Brown
Date: Thu Feb 21 2019 - 22:47:21 EST
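The "_show" functions that access channel ring buffer data are
vulnerable to a race condition that can result in a NULL pointer
dereference: the ring buffer pointer can be set to NULL after a "_show"
function has checked it but before the function dereferences it. This
problem was discussed here: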
https://lkml.org/lkml/2018/10/18/779
To prevent this from occurring, add a new mutex lock,
"ring_buffer_mutex", to the vmbus_channel struct.
Acquire/release "ring_buffer_mutex" around the hv_ringbuffer_cleanup()
calls in the functions that can set the ring buffer pointer to NULL:
vmbus_free_ring() and __vmbus_open().
Acquire/release "ring_buffer_mutex" in the four channel-level "_show"
functions that access ring buffer data. Remove the "const" qualifier
from the "struct vmbus_channel *chan" parameter of the channel-level
"_show" functions so that "ring_buffer_mutex" can be acquired/released
in these functions.
Acquire/release "ring_buffer_mutex" in hv_ringbuffer_get_debuginfo().
Pass the channel pointer to hv_ringbuffer_get_debuginfo() so that
"ring_buffer_mutex" can be accessed in this function.
Signed-off-by: Kimberly Brown <kimbrownkd@xxxxxxxxx>
---
drivers/hv/channel.c | 5 ++
drivers/hv/channel_mgmt.c | 1 +
drivers/hv/ring_buffer.c | 11 +++-
drivers/hv/vmbus_drv.c | 111 ++++++++++++++++++++++++--------------
include/linux/hyperv.h | 10 +++-
5 files changed, 96 insertions(+), 42 deletions(-)
diff --git a/drivers/hv/channel.c b/drivers/hv/channel.c
index 23381c41d087..7770e97e4202 100644
--- a/drivers/hv/channel.c
+++ b/drivers/hv/channel.c
@@ -82,8 +82,10 @@ EXPORT_SYMBOL_GPL(vmbus_setevent);
/* vmbus_free_ring - drop mapping of ring buffer */
void vmbus_free_ring(struct vmbus_channel *channel)
{
+ mutex_lock(&channel->ring_buffer_mutex);
hv_ringbuffer_cleanup(&channel->outbound);
hv_ringbuffer_cleanup(&channel->inbound);
+ mutex_unlock(&channel->ring_buffer_mutex);
if (channel->ringbuffer_page) {
__free_pages(channel->ringbuffer_page,
@@ -241,8 +243,11 @@ static int __vmbus_open(struct vmbus_channel *newchannel,
vmbus_teardown_gpadl(newchannel, newchannel->ringbuffer_gpadlhandle);
newchannel->ringbuffer_gpadlhandle = 0;
error_clean_ring:
+ mutex_lock(&newchannel->ring_buffer_mutex);
hv_ringbuffer_cleanup(&newchannel->outbound);
hv_ringbuffer_cleanup(&newchannel->inbound);
+ mutex_unlock(&newchannel->ring_buffer_mutex);
+
newchannel->state = CHANNEL_OPEN_STATE;
return err;
}
diff --git a/drivers/hv/channel_mgmt.c b/drivers/hv/channel_mgmt.c
index 62703b354d6d..769873cddfe5 100644
--- a/drivers/hv/channel_mgmt.c
+++ b/drivers/hv/channel_mgmt.c
@@ -329,6 +329,7 @@ static struct vmbus_channel *alloc_channel(void)
spin_lock_init(&channel->lock);
init_completion(&channel->rescind_event);
+ mutex_init(&channel->ring_buffer_mutex);
INIT_LIST_HEAD(&channel->sc_list);
INIT_LIST_HEAD(&channel->percpu_list);
diff --git a/drivers/hv/ring_buffer.c b/drivers/hv/ring_buffer.c
index 9e8b31ccc142..35de60d2c1e8 100644
--- a/drivers/hv/ring_buffer.c
+++ b/drivers/hv/ring_buffer.c
@@ -167,13 +167,18 @@ hv_get_ringbuffer_availbytes(const struct hv_ring_buffer_info *rbi,
/* Get various debug metrics for the specified ring buffer. */
int hv_ringbuffer_get_debuginfo(const struct hv_ring_buffer_info *ring_info,
- struct hv_ring_buffer_debug_info *debug_info)
+ struct hv_ring_buffer_debug_info *debug_info,
+ struct vmbus_channel *channel)
{
u32 bytes_avail_towrite;
u32 bytes_avail_toread;
- if (!ring_info->ring_buffer)
+ mutex_lock(&channel->ring_buffer_mutex);
+
+ if (!ring_info->ring_buffer) {
+ mutex_unlock(&channel->ring_buffer_mutex);
return -EINVAL;
+ }
hv_get_ringbuffer_availbytes(ring_info,
&bytes_avail_toread,
@@ -184,6 +189,8 @@ int hv_ringbuffer_get_debuginfo(const struct hv_ring_buffer_info *ring_info,
debug_info->current_write_index = ring_info->ring_buffer->write_index;
debug_info->current_interrupt_mask
= ring_info->ring_buffer->interrupt_mask;
+ mutex_unlock(&channel->ring_buffer_mutex);
+
return 0;
}
EXPORT_SYMBOL_GPL(hv_ringbuffer_get_debuginfo);
diff --git a/drivers/hv/vmbus_drv.c b/drivers/hv/vmbus_drv.c
index b02bcf1a9380..1ff767795d0a 100644
--- a/drivers/hv/vmbus_drv.c
+++ b/drivers/hv/vmbus_drv.c
@@ -345,9 +345,8 @@ static ssize_t out_intr_mask_show(struct device *dev,
if (!hv_dev->channel)
return -ENODEV;
-
ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->outbound,
- &outbound);
+ &outbound, hv_dev->channel);
if (ret < 0)
return ret;
@@ -366,7 +365,7 @@ static ssize_t out_read_index_show(struct device *dev,
return -ENODEV;
ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->outbound,
- &outbound);
+ &outbound, hv_dev->channel);
if (ret < 0)
return ret;
return sprintf(buf, "%d\n", outbound.current_read_index);
@@ -385,7 +384,7 @@ static ssize_t out_write_index_show(struct device *dev,
return -ENODEV;
ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->outbound,
- &outbound);
+ &outbound, hv_dev->channel);
if (ret < 0)
return ret;
return sprintf(buf, "%d\n", outbound.current_write_index);
@@ -404,7 +403,7 @@ static ssize_t out_read_bytes_avail_show(struct device *dev,
return -ENODEV;
ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->outbound,
- &outbound);
+ &outbound, hv_dev->channel);
if (ret < 0)
return ret;
return sprintf(buf, "%d\n", outbound.bytes_avail_toread);
@@ -423,7 +422,7 @@ static ssize_t out_write_bytes_avail_show(struct device *dev,
return -ENODEV;
ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->outbound,
- &outbound);
+ &outbound, hv_dev->channel);
if (ret < 0)
return ret;
return sprintf(buf, "%d\n", outbound.bytes_avail_towrite);
@@ -440,7 +439,8 @@ static ssize_t in_intr_mask_show(struct device *dev,
if (!hv_dev->channel)
return -ENODEV;
- ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound);
+ ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound,
+ hv_dev->channel);
if (ret < 0)
return ret;
@@ -458,7 +458,8 @@ static ssize_t in_read_index_show(struct device *dev,
if (!hv_dev->channel)
return -ENODEV;
- ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound);
+ ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound,
+ hv_dev->channel);
if (ret < 0)
return ret;
@@ -476,7 +477,8 @@ static ssize_t in_write_index_show(struct device *dev,
if (!hv_dev->channel)
return -ENODEV;
- ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound);
+ ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound,
+ hv_dev->channel);
if (ret < 0)
return ret;
@@ -495,7 +497,8 @@ static ssize_t in_read_bytes_avail_show(struct device *dev,
if (!hv_dev->channel)
return -ENODEV;
- ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound);
+ ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound,
+ hv_dev->channel);
if (ret < 0)
return ret;
@@ -514,7 +517,8 @@ static ssize_t in_write_bytes_avail_show(struct device *dev,
if (!hv_dev->channel)
return -ENODEV;
- ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound);
+ ret = hv_ringbuffer_get_debuginfo(&hv_dev->channel->inbound, &inbound,
+ hv_dev->channel);
if (ret < 0)
return ret;
@@ -1409,7 +1413,7 @@ static void vmbus_chan_release(struct kobject *kobj)
struct vmbus_chan_attribute {
struct attribute attr;
- ssize_t (*show)(const struct vmbus_channel *chan, char *buf);
+ ssize_t (*show)(struct vmbus_channel *chan, char *buf);
ssize_t (*store)(struct vmbus_channel *chan,
const char *buf, size_t count);
};
@@ -1428,7 +1432,7 @@ static ssize_t vmbus_chan_attr_show(struct kobject *kobj,
{
const struct vmbus_chan_attribute *attribute
= container_of(attr, struct vmbus_chan_attribute, attr);
- const struct vmbus_channel *chan
+ struct vmbus_channel *chan
= container_of(kobj, struct vmbus_channel, kobj);
if (!attribute->show)
@@ -1441,58 +1445,89 @@ static const struct sysfs_ops vmbus_chan_sysfs_ops = {
.show = vmbus_chan_attr_show,
};
-static ssize_t out_mask_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t out_mask_show(struct vmbus_channel *channel, char *buf)
{
const struct hv_ring_buffer_info *rbi = &channel->outbound;
+ ssize_t ret;
+
+ mutex_lock(&channel->ring_buffer_mutex);
- if (!rbi->ring_buffer)
+ if (!rbi->ring_buffer) {
+ mutex_unlock(&channel->ring_buffer_mutex);
return -EINVAL;
+ }
- return sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+ ret = sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+ mutex_unlock(&channel->ring_buffer_mutex);
+
+ return ret;
}
static VMBUS_CHAN_ATTR_RO(out_mask);
-static ssize_t in_mask_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t in_mask_show(struct vmbus_channel *channel, char *buf)
{
const struct hv_ring_buffer_info *rbi = &channel->inbound;
+ ssize_t ret;
- if (!rbi->ring_buffer)
+ mutex_lock(&channel->ring_buffer_mutex);
+
+ if (!rbi->ring_buffer) {
+ mutex_unlock(&channel->ring_buffer_mutex);
return -EINVAL;
+ }
+
+ ret = sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+ mutex_unlock(&channel->ring_buffer_mutex);
- return sprintf(buf, "%u\n", rbi->ring_buffer->interrupt_mask);
+ return ret;
}
static VMBUS_CHAN_ATTR_RO(in_mask);
-static ssize_t read_avail_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t read_avail_show(struct vmbus_channel *channel, char *buf)
{
const struct hv_ring_buffer_info *rbi = &channel->inbound;
+ ssize_t ret;
+
+ mutex_lock(&channel->ring_buffer_mutex);
- if (!rbi->ring_buffer)
+ if (!rbi->ring_buffer) {
+ mutex_unlock(&channel->ring_buffer_mutex);
return -EINVAL;
+ }
+
+ ret = sprintf(buf, "%u\n", hv_get_bytes_to_read(rbi));
+ mutex_unlock(&channel->ring_buffer_mutex);
- return sprintf(buf, "%u\n", hv_get_bytes_to_read(rbi));
+ return ret;
}
static VMBUS_CHAN_ATTR_RO(read_avail);
-static ssize_t write_avail_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t write_avail_show(struct vmbus_channel *channel, char *buf)
{
const struct hv_ring_buffer_info *rbi = &channel->outbound;
+ ssize_t ret;
- if (!rbi->ring_buffer)
+ mutex_lock(&channel->ring_buffer_mutex);
+
+ if (!rbi->ring_buffer) {
+ mutex_unlock(&channel->ring_buffer_mutex);
return -EINVAL;
+ }
- return sprintf(buf, "%u\n", hv_get_bytes_to_write(rbi));
+ ret = sprintf(buf, "%u\n", hv_get_bytes_to_write(rbi));
+ mutex_unlock(&channel->ring_buffer_mutex);
+
+ return ret;
}
static VMBUS_CHAN_ATTR_RO(write_avail);
-static ssize_t show_target_cpu(const struct vmbus_channel *channel, char *buf)
+static ssize_t show_target_cpu(struct vmbus_channel *channel, char *buf)
{
return sprintf(buf, "%u\n", channel->target_cpu);
}
static VMBUS_CHAN_ATTR(cpu, S_IRUGO, show_target_cpu, NULL);
-static ssize_t channel_pending_show(const struct vmbus_channel *channel,
- char *buf)
+static ssize_t channel_pending_show(struct vmbus_channel *channel, char *buf)
{
if (!channel->offermsg.monitor_allocated)
return -EINVAL;
@@ -1503,8 +1538,7 @@ static ssize_t channel_pending_show(const struct vmbus_channel *channel,
}
static VMBUS_CHAN_ATTR(pending, S_IRUGO, channel_pending_show, NULL);
-static ssize_t channel_latency_show(const struct vmbus_channel *channel,
- char *buf)
+static ssize_t channel_latency_show(struct vmbus_channel *channel, char *buf)
{
if (!channel->offermsg.monitor_allocated)
return -EINVAL;
@@ -1515,19 +1549,19 @@ static ssize_t channel_latency_show(const struct vmbus_channel *channel,
}
static VMBUS_CHAN_ATTR(latency, S_IRUGO, channel_latency_show, NULL);
-static ssize_t channel_interrupts_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t channel_interrupts_show(struct vmbus_channel *channel, char *buf)
{
return sprintf(buf, "%llu\n", channel->interrupts);
}
static VMBUS_CHAN_ATTR(interrupts, S_IRUGO, channel_interrupts_show, NULL);
-static ssize_t channel_events_show(const struct vmbus_channel *channel, char *buf)
+static ssize_t channel_events_show(struct vmbus_channel *channel, char *buf)
{
return sprintf(buf, "%llu\n", channel->sig_events);
}
static VMBUS_CHAN_ATTR(events, S_IRUGO, channel_events_show, NULL);
-static ssize_t channel_intr_in_full_show(const struct vmbus_channel *channel,
+static ssize_t channel_intr_in_full_show(struct vmbus_channel *channel,
char *buf)
{
return sprintf(buf, "%llu\n",
@@ -1535,7 +1569,7 @@ static ssize_t channel_intr_in_full_show(const struct vmbus_channel *channel,
}
static VMBUS_CHAN_ATTR(intr_in_full, 0444, channel_intr_in_full_show, NULL);
-static ssize_t channel_intr_out_empty_show(const struct vmbus_channel *channel,
+static ssize_t channel_intr_out_empty_show(struct vmbus_channel *channel,
char *buf)
{
return sprintf(buf, "%llu\n",
@@ -1543,7 +1577,7 @@ static ssize_t channel_intr_out_empty_show(const struct vmbus_channel *channel,
}
static VMBUS_CHAN_ATTR(intr_out_empty, 0444, channel_intr_out_empty_show, NULL);
-static ssize_t channel_out_full_first_show(const struct vmbus_channel *channel,
+static ssize_t channel_out_full_first_show(struct vmbus_channel *channel,
char *buf)
{
return sprintf(buf, "%llu\n",
@@ -1551,7 +1585,7 @@ static ssize_t channel_out_full_first_show(const struct vmbus_channel *channel,
}
static VMBUS_CHAN_ATTR(out_full_first, 0444, channel_out_full_first_show, NULL);
-static ssize_t channel_out_full_total_show(const struct vmbus_channel *channel,
+static ssize_t channel_out_full_total_show(struct vmbus_channel *channel,
char *buf)
{
return sprintf(buf, "%llu\n",
@@ -1559,7 +1593,7 @@ static ssize_t channel_out_full_total_show(const struct vmbus_channel *channel,
}
static VMBUS_CHAN_ATTR(out_full_total, 0444, channel_out_full_total_show, NULL);
-static ssize_t subchannel_monitor_id_show(const struct vmbus_channel *channel,
+static ssize_t subchannel_monitor_id_show(struct vmbus_channel *channel,
char *buf)
{
if (!channel->offermsg.monitor_allocated)
@@ -1569,8 +1603,7 @@ static ssize_t subchannel_monitor_id_show(const struct vmbus_channel *channel,
}
static VMBUS_CHAN_ATTR(monitor_id, S_IRUGO, subchannel_monitor_id_show, NULL);
-static ssize_t subchannel_id_show(const struct vmbus_channel *channel,
- char *buf)
+static ssize_t subchannel_id_show(struct vmbus_channel *channel, char *buf)
{
return sprintf(buf, "%u\n",
channel->offermsg.offer.sub_channel_index);
diff --git a/include/linux/hyperv.h b/include/linux/hyperv.h
index 64698ec8f2ac..6a6f79d7beba 100644
--- a/include/linux/hyperv.h
+++ b/include/linux/hyperv.h
@@ -934,6 +934,13 @@ struct vmbus_channel {
* full outbound ring buffer.
*/
u64 out_full_first;
+
+ /*
+ * The mutex lock that protects the channel's ring buffers. It's used to
+ * prevent the ring buffer pointers from being set to NULL while a
+ * function is accessing ring buffer data.
+ */
+ struct mutex ring_buffer_mutex;
};
static inline bool is_hvsock_channel(const struct vmbus_channel *c)
@@ -1207,7 +1214,8 @@ struct hv_ring_buffer_debug_info {
int hv_ringbuffer_get_debuginfo(const struct hv_ring_buffer_info *ring_info,
- struct hv_ring_buffer_debug_info *debug_info);
+ struct hv_ring_buffer_debug_info *debug_info,
+ struct vmbus_channel *channel);
/* Vmbus interface */
#define vmbus_driver_register(driver) \
--
2.17.1