Re: [PATCH v2 -next] sched/cputime: Fix mul_u64_u64_div_u64() precision for cputime

From: Oleg Nesterov
Date: Fri Jul 26 2024 - 06:46:32 EST


On 07/26, Zheng Zucheng wrote:
>
> before call mul_u64_u64_div_u64(),
> stime = 175136586720000, rtime = 135989749728000, utime = 1416780000.

So stime + utime == 175138003500000

> after call mul_u64_u64_div_u64(),
> stime = 135989949653530

Hmm. On x86 mul_u64_u64_div_u64(175136586720000, 135989749728000, 175138003500000)
returns 135989749728000 == rtime, see below.

Nevermind...

> --- a/kernel/sched/cputime.c
> +++ b/kernel/sched/cputime.c
> @@ -582,6 +582,12 @@ void cputime_adjust(struct task_cputime *curr, struct prev_cputime *prev,
> }
>
> stime = mul_u64_u64_div_u64(stime, rtime, stime + utime);
> + /*
> + * Because mul_u64_u64_div_u64() can approximate on some
> + * architectures; enforce the constraint that: a*b/(b+c) <= a.
> + */
> + if (unlikely(stime > rtime))
> + stime = rtime;

Thanks,

Acked-by: Oleg Nesterov <oleg@xxxxxxxxxx>

-------------------------------------------------------------------------------
But perhaps it makes sense to improve the accuracy of mul_u64_u64_div_u64() itself?
See the new() function in the code below.

For example, with the numbers above I get:

$ ./test 175136586720000 135989749728000 175138003500000
old -> 135989749728000 e=1100089950.609375
new -> 135988649638050 e=0.609375

Oleg.

-------------------------------------------------------------------------------
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>

typedef unsigned long long u64;

_Static_assert(sizeof(u64) == 8, "u64 must be exactly 64 bits wide");

/*
 * fls64 - find last (most significant) set bit, 1-based.
 * @x: value to scan
 *
 * Returns the 1-based position of the highest set bit of @x, or 0 when
 * @x == 0; e.g. fls64(1) == 1 and fls64(1ULL << 63) == 64.
 *
 * Uses __builtin_clzll() rather than x86-64 "bsrq" inline asm, so the
 * program builds on any GCC/Clang target and in strict ISO modes (where
 * the bare `asm` keyword is unavailable). BSR's destination is documented
 * by Intel as undefined for a zero source, and __builtin_clzll(0) is
 * undefined too, so the zero case is handled explicitly instead of
 * relying on the register keeping its old value.
 */
static inline int fls64(u64 x)
{
	if (x == 0)
		return 0;
	return 64 - __builtin_clzll(x);
}

/*
 * ilog2 - floor(log2(n)): the 0-based index of the highest set bit of @n.
 *
 * Defined as fls64(n) - 1, so it yields -1 for n == 0 provided fls64(0)
 * is 0. old()/new() add two of these together to test whether a product
 * might overflow 64 bits, and may pass a zero operand.
 */
static inline int ilog2(u64 n)
{
	int top_bit = fls64(n);

	return top_bit - 1;
}

/*
 * swap - exchange the values of two lvalues of the same type.
 *
 * Spelled with __typeof__, which GCC and Clang accept even in strict ISO
 * modes (-std=c11) where plain `typeof` is not a keyword. The temporary
 * avoids a leading double underscore, which is reserved for the
 * implementation. Like the kernel macro, each argument is evaluated more
 * than once, so do not pass expressions with side effects.
 */
#define swap(a, b) \
	do { __typeof__(a) swap_tmp_ = (a); (a) = (b); (b) = swap_tmp_; } while (0)

/*
 * div64_u64_rem - unsigned 64-by-64 division returning both parts.
 *
 * Stores dividend % divisor through @remainder and returns the quotient.
 * @divisor must be non-zero (division by zero is undefined in C).
 */
static inline u64 div64_u64_rem(u64 dividend, u64 divisor, u64 *remainder)
{
	const u64 quotient = dividend / divisor;

	/* For unsigned operands this is exactly dividend % divisor. */
	*remainder = dividend - quotient * divisor;
	return quotient;
}
/*
 * div64_u64 - plain unsigned 64-by-64 division (quotient only).
 * @divisor must be non-zero.
 */
static inline u64 div64_u64(u64 dividend, u64 divisor)
{
	u64 quotient;

	quotient = dividend / divisor;
	return quotient;
}

//-----------------------------------------------------------------------------

/*
 * old - copy of the current mul_u64_u64_div_u64() algorithm, kept for
 * comparison with new(): returns an approximation of (a * b) / c.
 *
 * If a * b cannot overflow, it is computed directly. Otherwise a single
 * reduction step is done (res = (b / c) * a; b = b % c). If the remaining
 * product can still overflow, both b and c are shifted right by the same
 * amount in one go. Apart from the final truncating division, that shift
 * is the only step that discards precision.
 *
 * Returns the partial result res if c is shifted down to zero. c must be
 * non-zero on entry, since it is used as a divisor.
 */
u64 old(u64 a, u64 b, u64 c)
{
	u64 res = 0, div, rem;
	int shift;

	/* can a * b overflow ? (ilog2() of a zero operand is -1) */
	if (ilog2(a) + ilog2(b) > 62) {
		/*
		 * Note that the algorithm after the if block below might lose
		 * some precision and the result is more exact for b > a. So
		 * exchange a and b if a is bigger than b.
		 *
		 * For example with a = 43980465100800, b = 100000000, c = 1000000000
		 * the below calculation doesn't modify b at all because div == 0
		 * and then shift becomes 45 + 26 - 62 = 9 and so the result
		 * becomes 4398035251080. However with a and b swapped the exact
		 * result is calculated (i.e. 4398046510080).
		 */
		if (a > b)
			swap(a, b);

		/*
		 * (b * a) / c is equal to
		 *
		 * (b / c) * a +
		 * (b % c) * a / c
		 *
		 * if nothing overflows. Can the 1st multiplication
		 * overflow? Yes, but we do not care: this can only
		 * happen if the end result can't fit in u64 anyway.
		 *
		 * So the code below does
		 *
		 * res = (b / c) * a;
		 * b = b % c;
		 */
		div = div64_u64_rem(b, c, &rem);
		res = div * a;
		b = rem;

		/* Number of low bits to drop so that a * b fits in 64 bits. */
		shift = ilog2(a) + ilog2(b) - 62;
		if (shift > 0) {
			/*
			 * drop precision: scale numerator factor b and divisor
			 * c by the same power of two, discarding their low bits
			 */
			b >>= shift;
			c >>= shift;
			/* c shifted down to zero: return the partial result */
			if (!c)
				return res;
		}
	}

	return res + div64_u64(a * b, c);
}

/*
 * new - proposed, more accurate variant of old(): returns (a * b) / c.
 *
 * Uses the same identity as old(),
 *
 *	(b * a) / c == (b / c) * a + (b % c) * a / c,
 *
 * but applies it in a loop that runs while a * b might overflow 64 bits.
 * On each pass the larger factor is kept in b. If b >= c, the quotient
 * part is folded into the result and b becomes b % c. Otherwise b and c
 * each lose one low bit.
 *
 * Precision is dropped one bit at a time, and the reduction is retried
 * after every shift. The aim is to discard less than old()'s single
 * shift; main() prints both results for comparison.
 *
 * As in old(), overflow of the accumulated result is not detected; it can
 * only happen when the true quotient does not fit in 64 bits. c must be
 * non-zero on entry.
 */
u64 new(u64 a, u64 b, u64 c)
{
	u64 acc = 0;

	for (;;) {
		u64 quot, rem;

		/* Stop once a * b is guaranteed to fit in 64 bits. */
		if (ilog2(a) + ilog2(b) <= 62)
			break;

		/* Keep the larger factor in b. */
		if (b < a)
			swap(a, b);

		if (b < c) {
			/* b / c would be 0: drop one bit of precision instead. */
			b >>= 1;
			c >>= 1;
			/*
			 * Defensive: inside the loop b >= 1, and b < c here,
			 * so c >= 2 before the shift.
			 */
			if (c == 0)
				return acc;
			continue;
		}

		/* acc += (b / c) * a; b = b % c; */
		quot = div64_u64_rem(b, c, &rem);
		acc += quot * a;
		b = rem;
	}

	return acc + div64_u64(a * b, c);
}

/*
 * parse_u64 - parse @s as strtoull() does with base 0 (decimal, 0x hex or
 * leading-0 octal) and store the value in *@out.
 *
 * Returns 0 on success, or -1 if @s contains no digits, has trailing
 * characters, or is out of range (ERANGE).
 */
static int parse_u64(const char *s, u64 *out)
{
	char *end;
	unsigned long long v;

	errno = 0;
	v = strtoull(s, &end, 0);
	if (end == s || *end != '\0' || errno == ERANGE)
		return -1;
	*out = v;
	return 0;
}

/*
 * Usage: test A B C
 *
 * Prints old(A, B, C) and new(A, B, C), each with its error relative to a
 * double-precision estimate of A * B / C.
 */
int main(int argc, char **argv)
{
	const char *prog = (argc > 0 && argv[0]) ? argv[0] : "test";
	u64 a, b, c, ro, rn;
	double rd;

	/*
	 * Validate explicitly rather than with assert(): assert() compiles
	 * away under NDEBUG and argv[1..3] would then be read unchecked.
	 */
	if (argc != 4 || parse_u64(argv[1], &a) || parse_u64(argv[2], &b) ||
	    parse_u64(argv[3], &c)) {
		fprintf(stderr, "usage: %s A B C  (computes A * B / C)\n", prog);
		return EXIT_FAILURE;
	}
	/* old() and new() divide by c: reject 0 instead of trapping. */
	if (c == 0) {
		fprintf(stderr, "%s: C must be non-zero\n", prog);
		return EXIT_FAILURE;
	}

	rd = (((double)a) * b) / c;
	ro = old(a, b, c);
	rn = new(a, b, c);

	/* u64 is unsigned long long: %llu, not %lld, avoids negative output. */
	printf("old -> %llu\te=%f\n", ro, ro - rd);
	printf("new -> %llu\te=%f\n", rn, rn - rd);

	return 0;
}