[PATCH] cred: add get_cred_many and put_cred_many

From: Mateusz Guzik
Date: Mon Aug 07 2023 - 12:31:48 EST


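Add get_new_cred_many(), get_cred_many() and put_cred_many() so that callers
taking or dropping several references on the same struct cred can do so with a
single atomic operation instead of back-to-back atomics. Use them in
copy_creds(), commit_creds() and exit_creds().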

Signed-off-by: Mateusz Guzik <mjguzik@xxxxxxxxx>
---
include/linux/cred.h | 27 +++++++++++++++++++++------
kernel/cred.c | 29 +++++++++++++++++------------
2 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/include/linux/cred.h b/include/linux/cred.h
index 9ed9232af934..b2b570ba204a 100644
--- a/include/linux/cred.h
+++ b/include/linux/cred.h
@@ -226,12 +226,65 @@ static inline bool cap_ambient_invariant_ok(const struct cred *cred)
* Get a reference on the specified set of new credentials. The caller must
* release the reference.
*/
static inline struct cred *get_new_cred(struct cred *cred)
{
atomic_inc(&cred->usage);
return cred;
}

+/**
+ * get_new_cred_many - Get references on a new set of credentials
+ * @cred: The new credentials to reference
+ * @nr: Number of references to acquire
+ *
+ * Get @nr references on the specified set of new credentials using a single
+ * atomic operation. The caller must release all of them.
+ */
+static inline struct cred *get_new_cred_many(struct cred *cred, int nr)
+{
+ atomic_add(nr, &cred->usage);
+ return cred;
+}
+
+/**
+ * get_cred_many - Get references on a set of credentials
+ * @cred: The credentials to reference
+ * @nr: Number of references to acquire
+ *
+ * Like get_cred(), but takes @nr references using a single atomic operation.
+ * The caller must release all of them. If %NULL is passed, it is returned
+ * with no action.
+ */
+static inline const struct cred *get_cred_many(const struct cred *cred, int nr)
+{
+ struct cred *nonconst_cred = (struct cred *) cred;
+ if (!cred)
+ return cred;
+ validate_creds(cred);
+ nonconst_cred->non_rcu = 0;
+ return get_new_cred_many(nonconst_cred, nr);
+}
+
+/**
+ * put_cred_many - Release references to a set of credentials
+ * @_cred: The credentials to release
+ * @nr: Number of references to release
+ *
+ * Like put_cred(), but drops @nr references using a single atomic operation,
+ * passing the credentials to __put_cred() if that drops the usage count to
+ * zero. If %NULL is passed, nothing is done.
+ */
+static inline void put_cred_many(const struct cred *_cred, int nr)
+{
+ struct cred *cred = (struct cred *) _cred;
+
+ if (cred) {
+ validate_creds(cred);
+ if (atomic_sub_and_test(nr, &cred->usage))
+ __put_cred(cred);
+ }
+}
+
/**
* get_cred - Get a reference on a set of credentials
* @cred: The credentials to reference
@@ -245,14 +298,9 @@ static inline struct cred *get_new_cred(struct cred *cred)
* accidental alteration of a set of credentials that should be considered
* immutable.
*/
static inline const struct cred *get_cred(const struct cred *cred)
{
- struct cred *nonconst_cred = (struct cred *) cred;
- if (!cred)
- return cred;
- validate_creds(cred);
- nonconst_cred->non_rcu = 0;
- return get_new_cred(nonconst_cred);
+ return get_cred_many(cred, 1);
}

static inline const struct cred *get_cred_rcu(const struct cred *cred)
@@ -278,17 +326,11 @@ static inline const struct cred *get_cred_rcu(const struct cred *cred)
* on task_struct are attached by const pointers to prevent accidental
* alteration of otherwise immutable credential sets.
*/
static inline void put_cred(const struct cred *_cred)
{
- struct cred *cred = (struct cred *) _cred;
-
- if (cred) {
- validate_creds(cred);
- if (atomic_dec_and_test(&(cred)->usage))
- __put_cred(cred);
- }
+ put_cred_many(_cred, 1);
}

/**
* current_cred - Access the current task's subjective credentials
*
diff --git a/kernel/cred.c b/kernel/cred.c
index 811ad654abd1..8a506bc7c1b8 100644
--- a/kernel/cred.c
+++ b/kernel/cred.c
@@ -159,23 +159,30 @@ EXPORT_SYMBOL(__put_cred);
*/
void exit_creds(struct task_struct *tsk)
{
- struct cred *cred;
+ struct cred *real_cred, *cred;

kdebug("exit_creds(%u,%p,%p,{%d,%d})", tsk->pid, tsk->real_cred, tsk->cred,
atomic_read(&tsk->cred->usage),
read_cred_subscribers(tsk->cred));

- cred = (struct cred *) tsk->real_cred;
+ real_cred = (struct cred *) tsk->real_cred;
tsk->real_cred = NULL;
- validate_creds(cred);
- alter_cred_subscribers(cred, -1);
- put_cred(cred);

cred = (struct cred *) tsk->cred;
tsk->cred = NULL;
- validate_creds(cred);
- alter_cred_subscribers(cred, -1);
- put_cred(cred);
+
+ validate_creds(cred);
+ /* Both pointers may share one cred: drop both references at once */
+ if (real_cred == cred) {
+ alter_cred_subscribers(cred, -2);
+ put_cred_many(cred, 2);
+ } else {
+ validate_creds(real_cred);
+ alter_cred_subscribers(real_cred, -1);
+ put_cred(real_cred);
+ alter_cred_subscribers(cred, -1);
+ put_cred(cred);
+ }

#ifdef CONFIG_KEYS_REQUEST_CACHE
key_put(tsk->cached_requested_key);
@@ -352,8 +359,7 @@ int copy_creds(struct task_struct *p, unsigned long clone_flags)
#endif
clone_flags & CLONE_THREAD
) {
- p->real_cred = get_cred(p->cred);
- get_cred(p->cred);
+ p->real_cred = get_cred_many(p->cred, 2);
alter_cred_subscribers(p->cred, 2);
kdebug("share_creds(%p{%d,%d})",
p->cred, atomic_read(&p->cred->usage),
@@ -517,8 +523,7 @@ int commit_creds(struct cred *new)
proc_id_connector(task, PROC_EVENT_GID);

/* release the old obj and subj refs both */
- put_cred(old);
- put_cred(old);
+ put_cred_many(old, 2);
return 0;
}
EXPORT_SYMBOL(commit_creds);
--
2.39.2