[RESEND][PATCH v3 15/17] static_call: Handle tail-calls

From: Peter Zijlstra
Date: Tue Mar 24 2020 - 10:25:15 EST


GCC can turn our static_call(name)(args...) into a tail call, in which
case we get a JMP.d32 into the trampoline (which then does a further
tail-call).

Teach objtool to recognise these tail-call sites and mark them in
.static_call_sites (by setting STATIC_CALL_SITE_TAIL in the low bits of
the site's key), and adjust the code patching to deal with such sites.

Signed-off-by: Peter Zijlstra (Intel) <peterz@xxxxxxxxxxxxx>
---
arch/x86/kernel/static_call.c | 4 ++--
include/linux/static_call.h | 4 ++--
include/linux/static_call_types.h | 7 +++++++
kernel/static_call.c | 21 +++++++++++++--------
tools/include/linux/static_call_types.h | 7 +++++++
tools/objtool/check.c | 18 +++++++++++++-----
6 files changed, 44 insertions(+), 17 deletions(-)

--- a/arch/x86/kernel/static_call.c
+++ b/arch/x86/kernel/static_call.c
@@ -41,7 +41,7 @@ static void __static_call_transform(void
text_poke_bp(insn, code, size, NULL);
}

-void arch_static_call_transform(void *site, void *tramp, void *func)
+void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
{
mutex_lock(&text_mutex);

@@ -49,7 +49,7 @@ void arch_static_call_transform(void *si
__static_call_transform(tramp, jmp + !func, func);

if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site)
- __static_call_transform(site, !func, func);
+ __static_call_transform(site, 2*tail + !func, func);

mutex_unlock(&text_mutex);
}
--- a/include/linux/static_call.h
+++ b/include/linux/static_call.h
@@ -64,7 +64,7 @@
/*
* Either @site or @tramp can be NULL.
*/
-extern void arch_static_call_transform(void *site, void *tramp, void *func);
+extern void arch_static_call_transform(void *site, void *tramp, void *func, bool tail);
#define STATIC_CALL_TRAMP_ADDR(name) &STATIC_CALL_TRAMP(name)
#else
#define STATIC_CALL_TRAMP_ADDR(name) NULL
@@ -140,7 +140,7 @@ void __static_call_update(struct static_
{
cpus_read_lock();
WRITE_ONCE(key->func, func);
- arch_static_call_transform(NULL, tramp, func);
+ arch_static_call_transform(NULL, tramp, func, false);
cpus_read_unlock();
}

--- a/include/linux/static_call_types.h
+++ b/include/linux/static_call_types.h
@@ -14,6 +14,13 @@
#define STATIC_CALL_TRAMP_STR(name) __stringify(STATIC_CALL_TRAMP(name))

/*
+ * Flags in the low bits of static_call_site::key.
+ */
+#define STATIC_CALL_SITE_TAIL 1UL /* tail call */
+#define STATIC_CALL_SITE_INIT 2UL /* init section */
+#define STATIC_CALL_SITE_FLAGS 3UL
+
+/*
* The static call site table needs to be created by external tooling (objtool
* or a compiler plugin).
*/
--- a/kernel/static_call.c
+++ b/kernel/static_call.c
@@ -15,8 +15,6 @@ extern struct static_call_site __start_s

static bool static_call_initialized;

-#define STATIC_CALL_INIT 1UL
-
/* mutex to protect key modules/sites */
static DEFINE_MUTEX(static_call_mutex);

@@ -39,18 +37,23 @@ static inline void *static_call_addr(str
static inline struct static_call_key *static_call_key(const struct static_call_site *site)
{
return (struct static_call_key *)
- (((long)site->key + (long)&site->key) & ~STATIC_CALL_INIT);
+ (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
}

/* These assume the key is word-aligned. */
static inline bool static_call_is_init(struct static_call_site *site)
{
- return ((long)site->key + (long)&site->key) & STATIC_CALL_INIT;
+ return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
+}
+
+static inline bool static_call_is_tail(struct static_call_site *site)
+{
+ return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
}

static inline void static_call_set_init(struct static_call_site *site)
{
- site->key = ((long)static_call_key(site) | STATIC_CALL_INIT) -
+ site->key = (((long)site->key + (long)&site->key) | STATIC_CALL_SITE_INIT) -
(long)&site->key;
}

@@ -104,7 +107,7 @@ void __static_call_update(struct static_

key->func = func;

- arch_static_call_transform(NULL, tramp, func);
+ arch_static_call_transform(NULL, tramp, func, false);

/*
* If uninitialized, we'll not update the callsites, but they still
@@ -153,7 +156,8 @@ void __static_call_update(struct static_
continue;
}

- arch_static_call_transform(site_addr, NULL, func);
+ arch_static_call_transform(site_addr, NULL, func,
+ static_call_is_tail(site));
}
}

@@ -197,7 +201,8 @@ static int __static_call_init(struct mod
key->next = site_mod;
}

- arch_static_call_transform(site_addr, NULL, key->func);
+ arch_static_call_transform(site_addr, NULL, key->func,
+ static_call_is_tail(site));
}

return 0;
--- a/tools/include/linux/static_call_types.h
+++ b/tools/include/linux/static_call_types.h
@@ -14,6 +14,13 @@
#define STATIC_CALL_TRAMP_STR(name) __stringify(STATIC_CALL_TRAMP(name))

/*
+ * Flags in the low bits of static_call_site::key.
+ */
+#define STATIC_CALL_SITE_TAIL 1UL /* tail call */
+#define STATIC_CALL_SITE_INIT 2UL /* init section */
+#define STATIC_CALL_SITE_FLAGS 3UL
+
+/*
* The static call site table needs to be created by external tooling (objtool
* or a compiler plugin).
*/
--- a/tools/objtool/check.c
+++ b/tools/objtool/check.c
@@ -585,6 +585,10 @@ static int add_jump_destinations(struct
} else {
/* external sibling call */
insn->call_dest = rela->sym;
+ if (insn->call_dest->static_call_tramp) {
+ list_add_tail(&insn->static_call_node,
+ &file->static_call_list);
+ }
continue;
}

@@ -636,6 +640,10 @@ static int add_jump_destinations(struct

/* internal sibling call */
insn->call_dest = insn->jump_dest->func;
+ if (insn->call_dest->static_call_tramp) {
+ list_add_tail(&insn->static_call_node,
+ &file->static_call_list);
+ }
}
}
}
@@ -1348,6 +1356,10 @@ static int decode_sections(struct objtoo
if (ret)
return ret;

+ ret = read_static_call_tramps(file);
+ if (ret)
+ return ret;
+
ret = add_jump_destinations(file);
if (ret)
return ret;
@@ -1372,10 +1384,6 @@ static int decode_sections(struct objtoo
if (ret)
return ret;

- ret = read_static_call_tramps(file);
- if (ret)
- return ret;
-
return 0;
}

@@ -2505,7 +2513,7 @@ static int create_static_call_sections(s
}
memset(rela, 0, sizeof(*rela));
rela->sym = key_sym;
- rela->addend = 0;
+ rela->addend = is_sibling_call(insn) ? STATIC_CALL_SITE_TAIL : 0;
rela->type = R_X86_64_PC32;
rela->offset = idx * sizeof(struct static_call_site) + 4;
list_add_tail(&rela->list, &rela_sec->rela_list);