Re: [PATCH bpf-next v2] bpf: Simplify tnum_step()

From: Harishankar Vishwanathan

Date: Sun Mar 22 2026 - 22:50:12 EST


On Fri, Mar 20, 2026 at 12:23 PM Hao Sun <sunhao.th@xxxxxxxxx> wrote:

Thanks for the patch. This is a neat simplification that turns the
algorithm into straight-line code. I really like the idea of working
with d = z - tmin instead of j, as the original did. The core idea of
adding 1 to the masked bits while keeping the fixed bits in place seems
to be the same.

I checked the new algorithm for soundness in z3 using the bitvector
theory, and it is indeed sound and equivalent to the old algorithm.
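
For readers without a z3 setup, a brute-force comparison over all 8-bit
tnums gives a quick sanity check. This is only a userspace sketch and not
the z3 encoding I used: fls64_sketch() is a stand-in for the kernel's
fls64(), and step_core(), step_naive() and count_mismatches_8bit() are
names made up for illustration.

```c
#include <stdint.h>

/* Userspace stand-in (assumption) for the kernel's fls64():
 * 1-based index of the highest set bit, 0 when x == 0. */
static unsigned int fls64_sketch(uint64_t x)
{
	unsigned int n = 0;

	while (x) {
		n++;
		x >>= 1;
	}
	return n;
}

/* The straight-line core from the patch, without the early returns.
 * Caller must ensure tval & tmask == 0 and tval <= z < (tval | tmask). */
static uint64_t step_core(uint64_t tval, uint64_t tmask, uint64_t z)
{
	uint64_t d = z - tval;
	uint64_t carry_mask = (1ULL << fls64_sketch(d & ~tmask)) - 1;
	uint64_t filled = d | carry_mask | ~tmask;
	uint64_t inc = (filled + 1) & tmask;

	return tval | inc;
}

/* Naive reference: scan upward from z + 1 for the first tnum member.
 * Terminates because tval | tmask is a member and z is below it. */
static uint64_t step_naive(uint64_t tval, uint64_t tmask, uint64_t z)
{
	uint64_t r = z + 1;

	while ((r & ~tmask) != tval)
		r++;
	return r;
}

/* Number of (tnum, z) cases, restricted to 8-bit values, where the
 * straight-line core and the naive scan disagree. */
static unsigned long count_mismatches_8bit(void)
{
	unsigned long bad = 0;
	uint64_t tval, tmask, z;

	for (tmask = 0; tmask < 256; tmask++)
		for (tval = 0; tval < 256; tval++) {
			if (tval & tmask)
				continue;	/* not a well-formed tnum */
			for (z = tval; z < (tval | tmask); z++)
				if (step_core(tval, tmask, z) !=
				    step_naive(tval, tmask, z))
					bad++;
		}
	return bad;
}
```

count_mismatches_8bit() should return 0 if the two agree. It only covers
8-bit inputs, so it complements the solver-based proofs rather than
replacing them.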

Some comments below, mainly regarding explanation and maintainability.

> Simplify tnum_step() from a 10-variable algorithm into a straight
> line sequence of bitwise operations.

The original algorithm was purposefully verbose for the sake of
explanation. I feel it is better to use temporaries that explain
relatively complex ideas. The compiler will optimize the subexpressions
anyway.

> Problem Reduction:
>
> tnum_step(): Given a tnum `(tval, tmask)` where `tval & tmask == 0`,
> and a value `z` with `tval ≤ z < (tval | tmask)`, find the smallest
> `r > z`, a tnum-satisfying value, i.e., `r & ~tmask == tval`.
>
> Every tnum-satisfying value has the form tval | s where s is a subset
> of tmask bits (s & ~tmask == 0). Since tval and tmask are disjoint:
>
> tval | s = tval + s

I'd start by explaining what "subset of tmask" means, and perhaps define
it as "submask" early on. Currently, the term submask is used in the
code comments but isn't explained anywhere.

Also, "disjoint" can be unclear to the reader. I'd explicitly state
that tval and tmask can never both have a 1 in the same bit position.
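
To make both definitions concrete, they could be spelled out as
predicates. This is just a sketch for the explanation, not something I
am asking to add to tnum.c; is_submask() and or_equals_add() are
hypothetical names.

```c
#include <stdbool.h>
#include <stdint.h>

/* s is a submask of tmask when every 1 bit of s is also 1 in tmask. */
static bool is_submask(uint64_t s, uint64_t tmask)
{
	return (s & ~tmask) == 0;
}

/* When tval and s never have a 1 in the same bit position, adding them
 * produces no carry, so OR and ADD give the same result. */
static bool or_equals_add(uint64_t tval, uint64_t s)
{
	return (tval & s) == 0 && (tval | s) == tval + s;
}
```

For example, with tmask = 0xD6, s = 0x12 is a submask while s = 0x20 is
not, because bit 5 is a fixed position.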

We should either call it "s" or "inc" everywhere.

> Similarly z = tval + d where d = z - tval, so r > z becomes:
>
> tval + s > tval + d
> s > d
>
> The problem reduces to: find the smallest s, a subset of tmask, such
> that s > d.
>
> Notice that `s` must be a subset of tmask, the problem now is simplified.

To be precise regarding minimality: we have to find the smallest
submask of tmask that is strictly greater than d.

> Algorithm:
>
> The mask bits of `d` form a "counter" that we want to increment by one,
> but the counter has gaps at the fixed-bit positions. A normal +1 would
> stop at the first 0-bit it meets; we need it to skip over fixed-bit
> gaps and land on the next mask bit.

I think it would be useful to illustrate this right away with an
example, to make it easier to follow.

This part of the algorithm shares its idea with the previous algorithm.
Much of the previous commit message explained the reasoning behind what
we are doing, but that explanation has been dropped.

You are first trying to identify the highest bit position k that is 1
in d and known/fixed in t. This is important because when constructing
s (or inc), setting any masked bit lower than k to 1 will not result in
a number > d. Only positions > k are candidates for being set to 1. For
example:

z = 1010100100
t = 10xx0x1xx1
d = 0010011011
          ^
          k
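
For reference, k in this example can be computed as in the sketch below.
highest_fixed_one() is a hypothetical helper that returns a 0-based bit
index, whereas the patch uses fls64(d & ~t.mask), which is 1-based (so
it yields k + 1, and 0 when there is no such bit).

```c
#include <stdint.h>

/* 0-based index of the highest bit that is 1 in d = z - tval and at a
 * fixed (non-mask) position of the tnum, or -1 if there is none. */
static int highest_fixed_one(uint64_t tval, uint64_t tmask, uint64_t z)
{
	uint64_t p = (z - tval) & ~tmask;
	int k = -1;

	while (p) {
		k++;
		p >>= 1;
	}
	return k;
}
```

With t = 10xx0x1xx1 (tval = 0x209, tmask = 0xD6) and z = 0x2A4 as above,
d = 0x9B and the helper returns 3.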

Borrowing wording from the previous commit: We must set masked positions
of significance higher than k. Specifically, we look for the next larger
combination of 1s to place in the masked positions, relative to the
combination that exists in d. We can achieve this by concatenating bits
at unknown positions of t into an integer, adding 1, and writing the
bits of that result back into the corresponding bit positions previously
extracted from d.

An example like in the previous commit might also help:

t =  10xx0x
d =  ..10.1
+         1
-----------
     ..11.0
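
The "concatenate, add 1, write back" description can also be written out
literally, which may help readers convince themselves that filling the
gaps with 1s does the same job. A minimal sketch, with hypothetical
helper names, that deliberately ignores the position-k handling
discussed above:

```c
#include <stdint.h>

/* Gather the bits of x selected by mask into the low bits of the result. */
static uint64_t extract_bits(uint64_t x, uint64_t mask)
{
	uint64_t out = 0;
	unsigned int n = 0;
	unsigned int i;

	for (i = 0; i < 64; i++) {
		if (!((mask >> i) & 1))
			continue;
		out |= ((x >> i) & 1) << n;
		n++;
	}
	return out;
}

/* Scatter the low bits of x back to the positions selected by mask.
 * Bits of x beyond the number of mask positions are dropped. */
static uint64_t deposit_bits(uint64_t x, uint64_t mask)
{
	uint64_t out = 0;
	unsigned int n = 0;
	unsigned int i;

	for (i = 0; i < 64; i++) {
		if (!((mask >> i) & 1))
			continue;
		out |= ((x >> n) & 1) << i;
		n++;
	}
	return out;
}

/* Next larger combination of 1s in the masked positions of d. */
static uint64_t next_in_mask(uint64_t d, uint64_t mask)
{
	return deposit_bits(extract_bits(d, mask) + 1, mask);
}
```

For the example above (mask = 0x0D, masked bits of d = 0x09), this
returns 0x0C, which matches ((d | ~mask) + 1) & mask, i.e. the
gap-filling form without carry_mask.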

> Step 1 -- plug the gaps:
>
> d | carry_mask | ~tmask
>
> - ~tmask fills all fixed-bit positions with 1.
> - carry_mask = (1 << fls64(d & ~tmask)) - 1 fills all positions
> (including mask positions) below the highest non-mask bit of d.
>
> After this, the only remaining 0s are mask bits above the highest
> non-mask bit of d where d is also 0 -- exactly the positions where
> the carry can validly land.
>
> Step 2 -- increment:
>
> (d | carry_mask | ~tmask) + 1
>
> Adding 1 flips all trailing 1s to 0 and sets the first 0 to 1. Since
> every gap has been plugged, that first 0 is guaranteed to be a mask bit
> above all non-mask bits of d.
>
> Step 3 -- mask:
>
> ((d | carry_mask | ~tmask) + 1) & tmask
>
> Strip the scaffolding, keeping only mask bits. Call the result inc.
>
> Step 4 -- result:
>
> tval | inc
>
> Reattach the fixed bits.
>
> A simple 8-bit example:

What are z and tval here in this example? It would be helpful to provide them.

>   tmask:        1 1 0 1 0 1 1 0
>   d:            1 0 1 0 0 0 1 0     (d = 162)
>                     ^
>                     non-mask 1 at bit 5
>
>   With carry_mask = 0b00111111 (smeared from bit 5):
>
>   d|carry|~tm   1 0 1 1 1 1 1 1
>   + 1           1 1 0 0 0 0 0 0
>   & tmask       1 1 0 0 0 0 0 0
>
> The patch passes my local test: test_verifier, test_progs for
> `-t verifier` and `-t reg_bounds`.
>
> CBMC shows the new code is equiv to original one[1], and
> a lean4 proof of correctness is available[2]:
>
>   theorem tnumStep_correct (tval tmask z : BitVec 64)
>       -- Precondition: valid tnum and input z
>       (h_consistent : (tval &&& tmask) = 0)
>       (h_lo : tval ≤ z)
>       (h_hi : z < (tval ||| tmask)) :
>       -- Postcondition: r must be:
>       -- (1) tnum member
>       -- (2) z < r
>       -- (3) for any other member w > z, r <= w
>       let r := tnumStep tval tmask z
>       satisfiesTnum64 r tval tmask ∧
>       tval ≤ r ∧ r ≤ (tval ||| tmask) ∧
>       z < r ∧
>       ∀ w, satisfiesTnum64 w tval tmask → z < w → r ≤ w := by
>     -- unfold definition
>     unfold tnumStep satisfiesTnum64
>     simp only []
>     refine ⟨?_, ?_, ?_, ?_, ?_⟩
>     -- the solver proves each conjunct
>     · bv_decide
>     · bv_decide
>     · bv_decide
>     · bv_decide
>     · intro w hw1 hw2; bv_decide
>
> [1] https://github.com/eddyz87/tnum-step-verif/blob/master/main.c
> [2] https://pastebin.com/raw/czHKiyY0
>
> Signed-off-by: Hao Sun <hao.sun@xxxxxxxxxxx>
> Acked-by: Eduard Zingerman <eddyz87@xxxxxxxxx>
> Acked-by: Shung-Hsi Yu <shung-hsi.yu@xxxxxxxx>

Reviewed-by: Harishankar Vishwanathan <harishankar.vishwanathan@xxxxxxxxx>

> ---
> v1 -> v2: inline proof, add code comments, add a variable `filled`.
>
> kernel/bpf/tnum.c | 46 +++++++++++++++++++---------------------------
> 1 file changed, 19 insertions(+), 27 deletions(-)
>
> diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
> index 4abc359b3db0..ec9c310cf5d7 100644
> --- a/kernel/bpf/tnum.c
> +++ b/kernel/bpf/tnum.c
> @@ -286,8 +286,7 @@ struct tnum tnum_bswap64(struct tnum a)
>   */
>  u64 tnum_step(struct tnum t, u64 z)
>  {
> -	u64 tmax, j, p, q, r, s, v, u, w, res;
> -	u8 k;
> +	u64 tmax, d, carry_mask, filled, inc;
>
>  	tmax = t.value | t.mask;
>
> @@ -299,29 +298,22 @@ u64 tnum_step(struct tnum t, u64 z)
>  	if (z < t.value)
>  		return t.value;
>
> -	/* keep t's known bits, and match all unknown bits to z */
> -	j = t.value | (z & t.mask);
> -
> -	if (j > z) {
> -		p = ~z & t.value & ~t.mask;
> -		k = fls64(p); /* k is the most-significant 0-to-1 flip */
> -		q = U64_MAX << k;
> -		r = q & z; /* positions > k matched to z */
> -		s = ~q & t.value; /* positions <= k matched to t.value */
> -		v = r | s;
> -		res = v;
> -	} else {
> -		p = z & ~t.value & ~t.mask;
> -		k = fls64(p); /* k is the most-significant 1-to-0 flip */
> -		q = U64_MAX << k;
> -		r = q & t.mask & z; /* unknown positions > k, matched to z */
> -		s = q & ~t.mask; /* known positions > k, set to 1 */
> -		v = r | s;
> -		/* add 1 to unknown positions > k to make value greater than z */
> -		u = v + (1ULL << k);
> -		/* extract bits in unknown positions > k from u, rest from t.value */
> -		w = (u & t.mask) | t.value;
> -		res = w;
> -	}
> -	return res;
> +	/*
> +	 * Let r be the result tnum member, z = t.value + d.
> +	 * Every tnum member is t.value | s for some submask s of t.mask,
> +	 * and since t.value & t.mask == 0, t.value | s == t.value + s.
> +	 * So r > z becomes s > d where d = z - t.value.
> +	 *
> +	 * Find the smallest submask s of t.mask greater than d by
> +	 * "incrementing d within the mask": fill every non-mask
> +	 * position with 1 (`filled`) so +1 ripples through the gaps,
> +	 * then keep only mask bits. `carry_mask` additionally fills
> +	 * positions below the highest non-mask 1 in d, preventing
> +	 * it from trapping the carry.
> +	 */
> +	d = z - t.value;
> +	carry_mask = (1ULL << fls64(d & ~t.mask)) - 1;
> +	filled = d | carry_mask | ~t.mask;
> +	inc = (filled + 1) & t.mask;
> +	return t.value | inc;
>  }
> --
> 2.34.1
>