[PATCH 01/23] net, sunrpc: convert rpc_cred.cr_count from atomic_t to refcount_t

From: Elena Reshetova
Date: Fri Mar 17 2017 - 08:11:15 EST


The refcount_t type and its corresponding API should be
used instead of atomic_t when a variable is used as
a reference counter. This helps avoid accidental
reference counter overflows that might lead to
use-after-free situations.

Signed-off-by: Elena Reshetova <elena.reshetova@xxxxxxxxx>
Signed-off-by: Hans Liljestrand <ishkamiel@xxxxxxxxx>
Signed-off-by: Kees Cook <keescook@xxxxxxxxxxxx>
Signed-off-by: David Windsor <dwindsor@xxxxxxxxx>
---
include/linux/sunrpc/auth.h | 8 ++++----
net/sunrpc/auth.c | 12 ++++++------
2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/include/linux/sunrpc/auth.h b/include/linux/sunrpc/auth.h
index b1bc62b..bd36e0b 100644
--- a/include/linux/sunrpc/auth.h
+++ b/include/linux/sunrpc/auth.h
@@ -15,7 +15,7 @@
#include <linux/sunrpc/msg_prot.h>
#include <linux/sunrpc/xdr.h>

-#include <linux/atomic.h>
+#include <linux/refcount.h>
#include <linux/rcupdate.h>
#include <linux/uidgid.h>
#include <linux/utsname.h>
@@ -68,7 +68,7 @@ struct rpc_cred {
#endif
unsigned long cr_expire; /* when to gc */
unsigned long cr_flags; /* various flags */
- atomic_t cr_count; /* ref count */
+ refcount_t cr_count; /* ref count */

kuid_t cr_uid;

@@ -209,7 +209,7 @@ static inline
struct rpc_cred * get_rpccred(struct rpc_cred *cred)
{
if (cred != NULL)
- atomic_inc(&cred->cr_count);
+ refcount_inc(&cred->cr_count);
return cred;
}

@@ -226,7 +226,7 @@ struct rpc_cred * get_rpccred(struct rpc_cred *cred)
static inline struct rpc_cred *
get_rpccred_rcu(struct rpc_cred *cred)
{
- if (atomic_inc_not_zero(&cred->cr_count))
+ if (refcount_inc_not_zero(&cred->cr_count))
return cred;
return NULL;
}
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c
index 2bff63a..b6439b9 100644
--- a/net/sunrpc/auth.c
+++ b/net/sunrpc/auth.c
@@ -310,7 +310,7 @@ rpcauth_unhash_cred(struct rpc_cred *cred)

cache_lock = &cred->cr_auth->au_credcache->lock;
spin_lock(cache_lock);
- ret = atomic_read(&cred->cr_count) == 0;
+ ret = refcount_read(&cred->cr_count) == 0;
if (ret)
rpcauth_unhash_cred_locked(cred);
spin_unlock(cache_lock);
@@ -470,12 +470,12 @@ rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
list_del_init(&cred->cr_lru);
number_cred_unused--;
freed++;
- if (atomic_read(&cred->cr_count) != 0)
+ if (refcount_read(&cred->cr_count) != 0)
continue;

cache_lock = &cred->cr_auth->au_credcache->lock;
spin_lock(cache_lock);
- if (atomic_read(&cred->cr_count) == 0) {
+ if (refcount_read(&cred->cr_count) == 0) {
get_rpccred(cred);
list_add_tail(&cred->cr_lru, free);
rpcauth_unhash_cred_locked(cred);
@@ -642,7 +642,7 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
{
INIT_HLIST_NODE(&cred->cr_hash);
INIT_LIST_HEAD(&cred->cr_lru);
- atomic_set(&cred->cr_count, 1);
+ refcount_set(&cred->cr_count, 1);
cred->cr_auth = auth;
cred->cr_ops = ops;
cred->cr_expire = jiffies;
@@ -715,12 +715,12 @@ put_rpccred(struct rpc_cred *cred)
return;
/* Fast path for unhashed credentials */
if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) {
- if (atomic_dec_and_test(&cred->cr_count))
+ if (refcount_dec_and_test(&cred->cr_count))
cred->cr_ops->crdestroy(cred);
return;
}

- if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
+ if (!refcount_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
return;
if (!list_empty(&cred->cr_lru)) {
number_cred_unused--;
--
2.7.4