[PATCH] thunderbolt: fix a missing-check bug

From: Wenwen Wang
Date: Sat Oct 20 2018 - 15:53:54 EST


In ring_work(), the first while loop collects all completed frames from
the ring buffer. In each iteration, the flags field of the frame's
descriptor, i.e., 'ring->descriptors[ring->tail].flags', is first checked
to see whether the frame is completed. If so, the descriptor of the frame,
including the flags, is then copied. Note that the descriptor lives in a
DMA region, which is allocated through dma_alloc_coherent() in
tb_ring_alloc(). Because the device can also access this region, a
malicious device controlled by an attacker can race to modify the flags
after the check but before the copy. By doing so, the attacker can bypass
the check and supply an uncompleted frame, which can cause undefined
behavior in the kernel and poses a potential security risk.

This patch first copies the flags into a local variable 'desc_flags' and
then performs both the check and the copy using 'desc_flags'. The value
that is checked is therefore the same value that is stored in the frame,
which avoids the issue described above.

Signed-off-by: Wenwen Wang <wang6495@xxxxxxx>
---
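Note (not part of the commit): the pattern described above can be
illustrated with a small user-space sketch. All names below (demo_desc,
demo_frame, DEMO_DESC_COMPLETED, demo_collect_*) are hypothetical stand-ins
and do not mirror the real layout in nhi.c. The volatile qualifier is only
an assumption of this sketch, used to model memory that another agent may
rewrite; the patch itself uses a plain assignment.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical flag bit; the value is illustrative only. */
#define DEMO_DESC_COMPLETED 0x2u

/* Hypothetical descriptor living in memory the device can also write. */
struct demo_desc {
	uint32_t length;
	uint32_t flags;
};

/* Hypothetical driver-private copy of a completed descriptor. */
struct demo_frame {
	uint32_t size;
	uint32_t flags;
};

/*
 * Racy pattern: flags is fetched twice, once for the check and once for
 * the copy. If the device rewrites flags in between, the stored value is
 * not the value that passed the check.
 */
static bool demo_collect_racy(const volatile struct demo_desc *d,
			      struct demo_frame *f)
{
	if (!(d->flags & DEMO_DESC_COMPLETED))	/* first fetch */
		return false;
	f->size = d->length;
	f->flags = d->flags;			/* second fetch */
	return true;
}

/*
 * Snapshot pattern: flags is fetched once into a local variable, and
 * both the check and the copy use that local value.
 */
static bool demo_collect_snapshot(const volatile struct demo_desc *d,
				  struct demo_frame *f)
{
	uint32_t desc_flags;

	assert(d != NULL && f != NULL);

	desc_flags = d->flags;			/* single fetch */
	if (!(desc_flags & DEMO_DESC_COMPLETED))
		return false;
	f->size = d->length;
	f->flags = desc_flags;
	return true;
}
```

With an unchanging descriptor both helpers behave identically; they can
differ only when flags changes between the two fetches in the racy
variant, which is the window a malicious device would target.
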
drivers/thunderbolt/nhi.c | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/drivers/thunderbolt/nhi.c b/drivers/thunderbolt/nhi.c
index 5cd6bdf..22bd6cf 100644
--- a/drivers/thunderbolt/nhi.c
+++ b/drivers/thunderbolt/nhi.c
@@ -215,6 +215,7 @@ static void ring_work(struct work_struct *work)
struct ring_frame *frame;
bool canceled = false;
unsigned long flags;
+ enum ring_desc_flags desc_flags;
LIST_HEAD(done);

spin_lock_irqsave(&ring->lock, flags);
@@ -228,8 +229,8 @@ static void ring_work(struct work_struct *work)
}

while (!ring_empty(ring)) {
- if (!(ring->descriptors[ring->tail].flags
- & RING_DESC_COMPLETED))
+ desc_flags = ring->descriptors[ring->tail].flags;
+ if (!(desc_flags & RING_DESC_COMPLETED))
break;
frame = list_first_entry(&ring->in_flight, typeof(*frame),
list);
@@ -238,7 +239,7 @@ static void ring_work(struct work_struct *work)
frame->size = ring->descriptors[ring->tail].length;
frame->eof = ring->descriptors[ring->tail].eof;
frame->sof = ring->descriptors[ring->tail].sof;
- frame->flags = ring->descriptors[ring->tail].flags;
+ frame->flags = desc_flags;
}
ring->tail = (ring->tail + 1) % ring->size;
}
--
2.7.4