Re: [PATCH] mm,vmscan: fix divide by zero in get_scan_count

From: Johannes Weiner
Date: Mon Aug 30 2021 - 16:46:21 EST


On Thu, Aug 26, 2021 at 10:01:49PM -0400, Rik van Riel wrote:
> Changeset f56ce412a59d ("mm: memcontrol: fix occasional OOMs due to
> proportional memory.low reclaim") introduced a divide by zero corner
> case when oomd is being used in combination with cgroup memory.low
> protection.
>
> When oomd decides to kill a cgroup, it will force the cgroup memory
> to be reclaimed after killing the tasks, by writing to the memory.max
> file for that cgroup, forcing the remaining page cache and reclaimable
> slab to be reclaimed down to zero.
>
> Previously, on cgroups with some memory.low protection, that would result
> in the memory being reclaimed down to the memory.low limit, or likely not
> at all, with the page cache reclaimed asynchronously later.
>
> With f56ce412a59d the oomd write to memory.max tries to reclaim all the
> way down to zero, which may race with another reclaimer, to the point of
> ending up with the divide by zero below.
>
> This patch implements the obvious fix.
>
> Fixes: f56ce412a59d ("mm: memcontrol: fix occasional OOMs due to proportional memory.low reclaim")
> Signed-off-by: Rik van Riel <riel@xxxxxxxxxxx>

That took me a second.

Before the patch, that sc->memcg_low_reclaim test was outside of that
whole proportional reclaim branch. So if we were in low reclaim mode
we wouldn't even check if a low setting is in place; if min is zero,
we don't enter the proportional branch.

Now we enter the branch if low is set, even though it is ignored, and
then end up with cgroup_size == min == 0 == divide by black hole.

Good catch.

Acked-by: Johannes Weiner <hannes@xxxxxxxxxxx>

> diff --git a/mm/vmscan.c b/mm/vmscan.c
> index eeae2f6bc532..f1782b816c98 100644
> --- a/mm/vmscan.c
> +++ b/mm/vmscan.c
> @@ -2592,7 +2592,7 @@ static void get_scan_count(struct lruvec *lruvec, struct scan_control *sc,
>  			cgroup_size = max(cgroup_size, protection);
>
>  			scan = lruvec_size - lruvec_size * protection /
> -				cgroup_size;
> +				(cgroup_size + 1);

I have no overly strong preferences, but if Michal prefers max(), how about:

cgroup_size = max3(cgroup_size, protection, 1UL);

Or go back to not taking the branch in the first place when there is
no protection in effect...

diff --git a/mm/vmscan.c b/mm/vmscan.c
index 6247f6f4469a..9c200bb3ae51 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -2547,7 +2547,7 @@ static void get_scan_count(struct lruvec *lruvec, struct scan_control *sc,
 		mem_cgroup_protection(sc->target_mem_cgroup, memcg,
 				      &min, &low);
 
-		if (min || low) {
+		if (min || (!sc->memcg_low_reclaim && low)) {
 			/*
 			 * Scale a cgroup's reclaim pressure by proportioning
 			 * its current usage to its memory.low or memory.min