[PATCH 5.4 11/98] smb3: Avoid Mid pending list corruption
From: Greg Kroah-Hartman
Date: Tue Dec 01 2020 - 04:27:21 EST
From: Rohith Surabattula <rohiths@xxxxxxxxxxxxx>
commit ac873aa3dc21707c47db5db6608b38981c731afe upstream.
When a reconnect happens, the mid queue can be corrupted if both the
demultiplex thread and the offload thread try to dequeue the mid from
the pending list.
These patches address a problem found during decryption offload:
CIFS: VFS: trying to dequeue a deleted mid
that could cause a refcount use after free:
Workqueue: smb3decryptd smb2_decrypt_offload [cifs]
Signed-off-by: Rohith Surabattula <rohiths@xxxxxxxxxxxxx>
Reviewed-by: Pavel Shilovsky <pshilov@xxxxxxxxxxxxx>
CC: Stable <stable@xxxxxxxxxxxxxxx> #5.4+
Signed-off-by: Steve French <stfrench@xxxxxxxxxxxxx>
Signed-off-by: Greg Kroah-Hartman <gregkh@xxxxxxxxxxxxxxxxxxx>
---
fs/cifs/smb2ops.c | 55 +++++++++++++++++++++++++++++++++++++++++++++---------
1 file changed, 46 insertions(+), 9 deletions(-)
--- a/fs/cifs/smb2ops.c
+++ b/fs/cifs/smb2ops.c
@@ -259,7 +259,7 @@ smb2_revert_current_mid(struct TCP_Serve
}
static struct mid_q_entry *
-smb2_find_mid(struct TCP_Server_Info *server, char *buf)
+__smb2_find_mid(struct TCP_Server_Info *server, char *buf, bool dequeue)
{
struct mid_q_entry *mid;
struct smb2_sync_hdr *shdr = (struct smb2_sync_hdr *)buf;
@@ -276,6 +276,10 @@ smb2_find_mid(struct TCP_Server_Info *se
(mid->mid_state == MID_REQUEST_SUBMITTED) &&
(mid->command == shdr->Command)) {
kref_get(&mid->refcount);
+ if (dequeue) {
+ list_del_init(&mid->qhead);
+ mid->mid_flags |= MID_DELETED;
+ }
spin_unlock(&GlobalMid_Lock);
return mid;
}
@@ -284,6 +288,18 @@ smb2_find_mid(struct TCP_Server_Info *se
return NULL;
}
+static struct mid_q_entry *
+smb2_find_mid(struct TCP_Server_Info *server, char *buf)
+{
+ return __smb2_find_mid(server, buf, false);
+}
+
+static struct mid_q_entry *
+smb2_find_dequeue_mid(struct TCP_Server_Info *server, char *buf)
+{
+ return __smb2_find_mid(server, buf, true);
+}
+
static void
smb2_dump_detail(void *buf, struct TCP_Server_Info *server)
{
@@ -4028,7 +4044,10 @@ handle_read_data(struct TCP_Server_Info
cifs_dbg(FYI, "%s: server returned error %d\n",
__func__, rdata->result);
/* normal error on read response */
- dequeue_mid(mid, false);
+ if (is_offloaded)
+ mid->mid_state = MID_RESPONSE_RECEIVED;
+ else
+ dequeue_mid(mid, false);
return 0;
}
@@ -4052,7 +4071,10 @@ handle_read_data(struct TCP_Server_Info
cifs_dbg(FYI, "%s: data offset (%u) beyond end of smallbuf\n",
__func__, data_offset);
rdata->result = -EIO;
- dequeue_mid(mid, rdata->result);
+ if (is_offloaded)
+ mid->mid_state = MID_RESPONSE_MALFORMED;
+ else
+ dequeue_mid(mid, rdata->result);
return 0;
}
@@ -4068,21 +4090,30 @@ handle_read_data(struct TCP_Server_Info
cifs_dbg(FYI, "%s: data offset (%u) beyond 1st page of response\n",
__func__, data_offset);
rdata->result = -EIO;
- dequeue_mid(mid, rdata->result);
+ if (is_offloaded)
+ mid->mid_state = MID_RESPONSE_MALFORMED;
+ else
+ dequeue_mid(mid, rdata->result);
return 0;
}
if (data_len > page_data_size - pad_len) {
/* data_len is corrupt -- discard frame */
rdata->result = -EIO;
- dequeue_mid(mid, rdata->result);
+ if (is_offloaded)
+ mid->mid_state = MID_RESPONSE_MALFORMED;
+ else
+ dequeue_mid(mid, rdata->result);
return 0;
}
rdata->result = init_read_bvec(pages, npages, page_data_size,
cur_off, &bvec);
if (rdata->result != 0) {
- dequeue_mid(mid, rdata->result);
+ if (is_offloaded)
+ mid->mid_state = MID_RESPONSE_MALFORMED;
+ else
+ dequeue_mid(mid, rdata->result);
return 0;
}
@@ -4097,7 +4128,10 @@ handle_read_data(struct TCP_Server_Info
/* read response payload cannot be in both buf and pages */
WARN_ONCE(1, "buf can not contain only a part of read data");
rdata->result = -EIO;
- dequeue_mid(mid, rdata->result);
+ if (is_offloaded)
+ mid->mid_state = MID_RESPONSE_MALFORMED;
+ else
+ dequeue_mid(mid, rdata->result);
return 0;
}
@@ -4108,7 +4142,10 @@ handle_read_data(struct TCP_Server_Info
if (length < 0)
return length;
- dequeue_mid(mid, false);
+ if (is_offloaded)
+ mid->mid_state = MID_RESPONSE_RECEIVED;
+ else
+ dequeue_mid(mid, false);
return length;
}
@@ -4137,7 +4174,7 @@ static void smb2_decrypt_offload(struct
}
dw->server->lstrp = jiffies;
- mid = smb2_find_mid(dw->server, dw->buf);
+ mid = smb2_find_dequeue_mid(dw->server, dw->buf);
if (mid == NULL)
cifs_dbg(FYI, "mid not found\n");
else {