Re: [patch 37/38] x86/bpf: Emit call depth accounting if required

From: Alexei Starovoitov
Date: Tue Jul 19 2022 - 01:30:30 EST


On Sat, Jul 16, 2022 at 4:18 PM Thomas Gleixner <tglx@xxxxxxxxxxxxx> wrote:
>
> Ensure that calls in BPF jitted programs are emitting call depth accounting
> when enabled to keep the call/return balanced. The return thunk jump is
> already injected due to the earlier retbleed mitigations.
>
> Signed-off-by: Thomas Gleixner <tglx@xxxxxxxxxxxxx>
> Cc: Alexei Starovoitov <ast@xxxxxxxxxx>
> Cc: Daniel Borkmann <daniel@xxxxxxxxxxxxx>
> ---
> arch/x86/include/asm/alternative.h | 6 +++++
> arch/x86/kernel/callthunks.c | 19 ++++++++++++++++
> arch/x86/net/bpf_jit_comp.c | 43 ++++++++++++++++++++++++-------------
> 3 files changed, 53 insertions(+), 15 deletions(-)
>
> --- a/arch/x86/include/asm/alternative.h
> +++ b/arch/x86/include/asm/alternative.h
> @@ -95,6 +95,7 @@ extern void callthunks_patch_module_call
> extern void callthunks_module_free(struct module *mod);
> extern void *callthunks_translate_call_dest(void *dest);
> extern bool is_callthunk(void *addr);
> +extern int x86_call_depth_emit_accounting(u8 **pprog, void *func);
> #else
> static __always_inline void callthunks_patch_builtin_calls(void) {}
> static __always_inline void
> @@ -109,6 +110,11 @@ static __always_inline bool is_callthunk
> {
> return false;
> }
> +static __always_inline int x86_call_depth_emit_accounting(u8 **pprog,
> + void *func)
> +{
> + return 0;
> +}
> #endif
>
> #ifdef CONFIG_SMP
> --- a/arch/x86/kernel/callthunks.c
> +++ b/arch/x86/kernel/callthunks.c
> @@ -706,6 +706,25 @@ int callthunk_get_kallsym(unsigned int s
> return ret;
> }
>
> +#ifdef CONFIG_BPF_JIT
> +int x86_call_depth_emit_accounting(u8 **pprog, void *func)
> +{
> + unsigned int tmpl_size = callthunk_desc.template_size;
> + void *tmpl = callthunk_desc.template;
> +
> + if (!thunks_initialized)
> + return 0;
> +
> + /* Is function call target a thunk? */
> + if (is_callthunk(func))
> + return 0;
> +
> + memcpy(*pprog, tmpl, tmpl_size);
> + *pprog += tmpl_size;
> + return tmpl_size;
> +}
> +#endif
> +
> #ifdef CONFIG_MODULES
> void noinline callthunks_patch_module_calls(struct callthunk_sites *cs,
> struct module *mod)
> --- a/arch/x86/net/bpf_jit_comp.c
> +++ b/arch/x86/net/bpf_jit_comp.c
> @@ -340,6 +340,12 @@ static int emit_call(u8 **pprog, void *f
> return emit_patch(pprog, func, ip, 0xE8);
> }
>
> +static int emit_rsb_call(u8 **pprog, void *func, void *ip)
> +{
> + x86_call_depth_emit_accounting(pprog, func);
> + return emit_patch(pprog, func, ip, 0xE8);
> +}
> +
> static int emit_jump(u8 **pprog, void *func, void *ip)
> {
> return emit_patch(pprog, func, ip, 0xE9);
> @@ -1431,19 +1437,26 @@ st: if (is_imm8(insn->off))
> break;
>
> /* call */
> - case BPF_JMP | BPF_CALL:
> + case BPF_JMP | BPF_CALL: {
> + int offs;
> +
> func = (u8 *) __bpf_call_base + imm32;
> if (tail_call_reachable) {
> /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
> EMIT3_off32(0x48, 0x8B, 0x85,
> -round_up(bpf_prog->aux->stack_depth, 8) - 8);
> - if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7))
> + if (!imm32)
> return -EINVAL;
> + offs = 7 + x86_call_depth_emit_accounting(&prog, func);

It's a bit hard to read all the macro magic in patches 28-30,
but I suspect the asm inside
callthunk_desc.template
that will be emitted here before the call
will do
some math on %rax
movq %rax, PER_CPU_VAR(__x86_call_depth).

Only %rax register is scratched by the callthunk_desc, right?
If so, it's ok for all cases except this one.
See the comment a few lines above,
after if (tail_call_reachable),
and commit ebf7d1f508a7 ("bpf, x64: rework pro/epilogue and tailcall
handling in JIT")
We use %rax to keep the tail_call count.
The callthunk_desc would need to preserve %rax.
I guess extra push %rax/pop %rax would do it.

> } else {
> - if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
> + if (!imm32)
> return -EINVAL;
> + offs = x86_call_depth_emit_accounting(&prog, func);
> }
> + if (emit_call(&prog, func, image + addrs[i - 1] + offs))
> + return -EINVAL;
> break;
> + }