[PATCH -mm] fault-inject: avoid unwanted data race on task->fail_nth

From: Akinobu Mita
Date: Thu Jul 13 2017 - 12:15:18 EST


The fault-inject-make-fail-nth-read-write-interface-symmetric.patch in the
-mm tree allows users to set task->fail_nth for a non-current task via
procfs. On the other hand, the current task's fail_nth is decremented
toward zero in the fault-injection path without holding any lock.

So we need to prevent task->fail_nth from taking an unexpected value due
to data races (for example, task->fail_nth being set to zero while
current->fail_nth is being decremented). This fix uses READ_ONCE() and
WRITE_ONCE() to prevent the compiler from creating unsolicited accesses.

Cc: Dmitry Vyukov <dvyukov@xxxxxxxxxx>
Reported-by: Dmitry Vyukov <dvyukov@xxxxxxxxxx>
Signed-off-by: Akinobu Mita <akinobu.mita@xxxxxxxxx>
---
fs/proc/base.c | 5 +++--
lib/fault-inject.c | 7 +++++--
2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/fs/proc/base.c b/fs/proc/base.c
index ecc8a25..719c2e9 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -1370,7 +1370,7 @@ static ssize_t proc_fail_nth_write(struct file *file, const char __user *buf,
task = get_proc_task(file_inode(file));
if (!task)
return -ESRCH;
- task->fail_nth = n;
+ WRITE_ONCE(task->fail_nth, n);
put_task_struct(task);

return count;
@@ -1386,7 +1386,8 @@ static ssize_t proc_fail_nth_read(struct file *file, char __user *buf,
task = get_proc_task(file_inode(file));
if (!task)
return -ESRCH;
- len = snprintf(numbuf, sizeof(numbuf), "%u\n", task->fail_nth);
+ len = snprintf(numbuf, sizeof(numbuf), "%u\n",
+ READ_ONCE(task->fail_nth));
len = simple_read_from_buffer(buf, count, ppos, numbuf, len);
put_task_struct(task);

diff --git a/lib/fault-inject.c b/lib/fault-inject.c
index 09ac73c1..7d315fd 100644
--- a/lib/fault-inject.c
+++ b/lib/fault-inject.c
@@ -107,9 +107,12 @@ static inline bool fail_stacktrace(struct fault_attr *attr)

bool should_fail(struct fault_attr *attr, ssize_t size)
{
- if (in_task() && current->fail_nth) {
- if (--current->fail_nth == 0)
+ if (in_task()) {
+ unsigned int fail_nth = READ_ONCE(current->fail_nth);
+
+ if (fail_nth && !WRITE_ONCE(current->fail_nth, fail_nth - 1))
goto fail;
+
return false;
}

--
2.7.4