Re: bpf: shift-out-of-bounds in tnum_rshift()

From: Eduard Zingerman
Date: Wed Oct 25 2023 - 13:34:50 EST


On Tue, 2023-10-24 at 14:40 +0200, Hao Sun wrote:
> Hi,
>
> The following program can trigger a shift-out-of-bounds in
> tnum_rshift(), called by scalar32_min_max_rsh():
>
> 0: (bc) w0 = w1
> 1: (bf) r2 = r0
> 2: (18) r3 = 0xd
> 4: (bc) w4 = w0
> 5: (bf) r5 = r0
> 6: (bf) r7 = r3
> 7: (bf) r8 = r4
> 8: (2f) r8 *= r5
> 9: (cf) r5 s>>= r5
> 10: (a6) if w8 < 0xfffffffb goto pc+10
> 11: (1f) r7 -= r5
> 12: (71) r6 = *(u8 *)(r1 +17)
> 13: (5f) r3 &= r8
> 14: (74) w2 >>= 30
> 15: (1f) r7 -= r5
> 16: (5d) if r8 != r6 goto pc+4
> 17: (c7) r8 s>>= 5
> 18: (cf) r0 s>>= r0
> 19: (7f) r0 >>= r0
> 20: (7c) w5 >>= w8 # shift-out-bounds here
> 21: exit

Here is a simplified example:

SEC("?tp")
__success __retval(0)
__naked void large_shifts(void)
{
asm volatile (" \
call %[bpf_get_prandom_u32]; \n\
r8 = r0; \n\
r6 = r0; \n\
r6 &= 0xf; \n\
if w8 < 0xffffffff goto +2; \n\
if r8 != r6 goto +1; \n\
w0 >>= w8; /* shift-out-bounds here */ \n\
exit; \n\
" :
: __imm(bpf_get_prandom_u32)
: __clobber_all);
}

The issue is caused by an invalid range assigned to R8 after the R8 != R6
check. Here is the GDB log:

(gdb) bt
#0 scalar32_min_max_rsh ... at kernel/bpf/verifier.c:13368
#1 0xffffffff81295236 in adjust_scalar_min_max_vals ... at kernel/bpf/verifier.c:13592
#2 adjust_reg_min_max_vals .... at kernel/bpf/verifier.c:13706
#3 0xffffffff8128701f in check_alu_op ... at kernel/bpf/verifier.c:13938
#4 do_check ... at kernel/bpf/verifier.c:17327
(gdb) p *src_reg
$2 = {
type = SCALAR_VALUE,
...
smin_value = 4294967295,
smax_value = 15,
umin_value = 4294967295,
umax_value = 15,
s32_min_value = -1,
s32_max_value = -1,
u32_min_value = 4294967295,
u32_max_value = 4294967295,
...
}

The invalid range is assigned within the reg_combine_min_max() function, in
the BPF_JNE branch. The following diff removes the error:

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 857d76694517..3d140bf85282 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -14485,7 +14485,7 @@ static void reg_combine_min_max(struct bpf_reg_state *true_src,
__reg_combine_min_max(true_src, true_dst);
break;
case BPF_JNE:
- __reg_combine_min_max(false_src, false_dst);
+ //__reg_combine_min_max(false_src, false_dst);
break;
}
}

I do not understand what the BPF_JNE branch logically means in
reg_combine_min_max(). Does anyone have any insight?

> After load:
> ================================================================================
> UBSAN: shift-out-of-bounds in kernel/bpf/tnum.c:44:9
> shift exponent 255 is too large for 64-bit type 'long long unsigned int'
> CPU: 2 PID: 8574 Comm: bpf-test Not tainted
> 6.6.0-rc5-01400-g7c2f6c9fb91f-dirty #21
> Hardware name: QEMU Standard PC (i440FX + PIIX, 1996), BIOS 1.15.0-1 04/01/2014
> Call Trace:
> <TASK>
> __dump_stack lib/dump_stack.c:88 [inline]
> dump_stack_lvl+0x8e/0xb0 lib/dump_stack.c:106
> ubsan_epilogue lib/ubsan.c:217 [inline]
> __ubsan_handle_shift_out_of_bounds+0x15a/0x2f0 lib/ubsan.c:387
> tnum_rshift.cold+0x17/0x32 kernel/bpf/tnum.c:44
> scalar32_min_max_rsh kernel/bpf/verifier.c:12999 [inline]
> adjust_scalar_min_max_vals kernel/bpf/verifier.c:13224 [inline]
> adjust_reg_min_max_vals+0x1936/0x5d50 kernel/bpf/verifier.c:13338
> do_check kernel/bpf/verifier.c:16890 [inline]
> do_check_common+0x2f64/0xbb80 kernel/bpf/verifier.c:19563
> do_check_main kernel/bpf/verifier.c:19626 [inline]
> bpf_check+0x65cf/0xa9e0 kernel/bpf/verifier.c:20263
> bpf_prog_load+0x110e/0x1b20 kernel/bpf/syscall.c:2717
> __sys_bpf+0xfcf/0x4380 kernel/bpf/syscall.c:5365
> __do_sys_bpf kernel/bpf/syscall.c:5469 [inline]
> __se_sys_bpf kernel/bpf/syscall.c:5467 [inline]
> __x64_sys_bpf+0x73/0xb0 kernel/bpf/syscall.c:5467
> do_syscall_x64 arch/x86/entry/common.c:50 [inline]
> do_syscall_64+0x39/0xb0 arch/x86/entry/common.c:80
> entry_SYSCALL_64_after_hwframe+0x63/0xcd
> RIP: 0033:0x5610511e23cd
> Code: 24 80 00 00 00 48 0f 42 d0 48 89 94 24 68 0c 00 00 b8 41 01 00
> 00 bf 05 00 00 00 ba 90 00 00 00 48 8d b44
> RSP: 002b:00007f5357fc7820 EFLAGS: 00000246 ORIG_RAX: 0000000000000141
> RAX: ffffffffffffffda RBX: 0000000000000095 RCX: 00005610511e23cd
> RDX: 0000000000000090 RSI: 00007f5357fc8410 RDI: 0000000000000005
> RBP: 0000000000000000 R08: 00007f5357fca458 R09: 00007f5350005520
> R10: 0000000000000000 R11: 0000000000000246 R12: 000000000000002b
> R13: 0000000d00000000 R14: 000000000000002b R15: 000000000000002b
> </TASK>
>
> If insn #20 is removed, the verifier gives:
> -------- Verifier Log --------
> func#0 @0
> 0: R1=ctx(off=0,imm=0) R10=fp0
> 0: (bc) w0 = w1 ;
> R0_w=scalar(smin=0,smax=umax=4294967295,var_off=(0x0; 0xffffffff))
> R1=ctx(off=0,
> imm=0)
> 1: (bf) r2 = r0 ;
> R0_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0;
> 0xffffffff))
> R2_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0; 0xffffffff))
> 2: (18) r3 = 0xd ; R3_w=13
> 4: (bc) w4 = w0 ;
> R0_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0;
> 0xffffffff))
> R4_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0; 0xffffffff))
> 5: (bf) r5 = r0 ;
> R0_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0;
> 0xffffffff))
> R5_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0; 0xffffffff))
> 6: (bf) r7 = r3 ; R3_w=13 R7_w=13
> 7: (bf) r8 = r4 ;
> R4_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0;
> 0xffffffff))
> R8_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0; 0xffffffff))
> 8: (2f) r8 *= r5 ;
> R5_w=scalar(id=1,smin=0,smax=umax=4294967295,var_off=(0x0;
> 0xffffffff))
> R8_w=scalar()
> 9: (cf) r5 s>>= r5 ; R5_w=scalar()
> 10: (a6) if w8 < 0xfffffffb goto pc+9 ;
> R8_w=scalar(smin=-9223372032559808520,umin=4294967288,smin32=-5,smax32=-1,
> umin32=4294967291,var_off=(0xfffffff8; 0xffffffff00000007))
> 11: (1f) r7 -= r5 ; R5_w=scalar() R7_w=scalar()
> 12: (71) r6 = *(u8 *)(r1 +17) ; R1=ctx(off=0,imm=0)
> R6_w=scalar(smin=smin32=0,smax=umax=smax32=umax32=255,
> var_off=(0x0; 0xff))
> 13: (5f) r3 &= r8 ;
> R3_w=scalar(smin=umin=smin32=umin32=8,smax=umax=smax32=umax32=13,var_off=(0x8;
> 0x5)) R8_w=scalar(smin=-9223372032559808520,umin=4294967288,smin32=-5,smax32=-1,umin32=4294967291,var_off=(0xffff)
> 14: (74) w2 >>= 30 ;
> R2_w=scalar(smin=smin32=0,smax=umax=smax32=umax32=3,var_off=(0x0;
> 0x3))
> 15: (1f) r7 -= r5 ; R5_w=scalar() R7_w=scalar()
> 16: (5d) if r8 != r6 goto pc+3 ;
> R6_w=scalar(smin=umin=umin32=4294967288,smax=umax=umax32=255,smin32=-8,smax32=-1,
> var_off=(0xfffffff8; 0x7))
> R8_w=scalar(smin=umin=4294967288,smax=umax=255,smin32=-5,smax32=-1,umin32=4294967291)
> 17: (c7) r8 s>>= 5 ; R8_w=134217727
> 18: (cf) r0 s>>= r0 ; R0_w=scalar()
> 19: (7f) r0 >>= r0 ; R0=scalar()
> 20: (95) exit
>
> from 16 to 20: safe
>
> from 10 to 20: safe
> processed 22 insns (limit 1000000) max_states_per_insn 0 total_states
> 1 peak_states 1 mark_read 1
> -------- End of Verifier Log --------
>
> In adjust_scalar_min_max_vals(), src_reg.umax_value is 7, thus passing
> the check here:
> if (umax_val >= insn_bitness) {
> /* Shifts greater than 31 or 63 are undefined.
> * This includes shifts by a negative number.
> */
> mark_reg_unknown(env, regs, insn->dst_reg);
> break;
> }
>
> However, in scalar32_min_max_rsh(), both src_reg->u32_min_value and
> src_reg->u32_max_value are 134217727, causing tnum_rshift() to shift by 255.
>
> Should we check if (src_reg->u32_max_value < insn_bitness) before calling
> scalar32_min_max_rsh(), rather than only checking umax_val? Or is it
> because of an issue somewhere else, incorrectly setting u32_min_value to
> 134217727?
>
> Best
> Hao Sun
>