[PATCH v2] platform/chrome: sensorhub: Implement quickselect for median calculation

From: Kuan-Wei Chiu
Date: Fri Nov 10 2023 - 13:17:17 EST


The cros_ec_sensor_ring_median function currently uses an inefficient
sorting algorithm (> O(n)) to find the median of an array. This patch
replaces the sorting approach with the quickselect algorithm, which
achieves an average time complexity of O(n).

The algorithm employs the median-of-three rule to select the pivot,
mitigating worst-case scenarios and reducing the expected number of
necessary comparisons. This strategy enhances the algorithm's
efficiency and ensures a more balanced partitioning.

In the worst case, the runtime of quickselect could regress to O(n^2).
To address this, alternative algorithms like median-of-medians exist
that can guarantee O(n) even in the worst case. However, due to their
higher overhead and increased implementation complexity, quickselect
remains a pragmatic choice for our use case.

Signed-off-by: Kuan-Wei Chiu <visitorckw@xxxxxxxxx>
---
v1 -> v2:
* Separate patch series into two patches.
* Modify the microbenchmark[1] to set n=64 and repeat the run 10000 times.
* Enhance coding style and comments.

[1]:
/*
 * init_array - fill @arr with a deterministic pseudo-random sequence
 * @arr: buffer receiving @length values
 * @length: number of elements to write
 * @seed: starting state of the generator
 *
 * Each element is the next state of seed = (seed * 725861) % 6599, so the
 * same @seed always reproduces the same contents.
 */
static void init_array(s64 *arr, size_t length, s64 seed)
{
	/* Index with size_t so the comparison against @length stays unsigned. */
	for (size_t i = 0; i < length; i++) {
		seed = (seed * 725861) % 6599;
		arr[i] = seed;
	}
}

/*
 * quickselect_test - compare the old and new median implementations
 *
 * Runs cros_ec_sensor_ring_median() and cros_ec_sensor_ring_median_new()
 * on identically seeded 64-element arrays for 10000 rounds, accumulates
 * the ktime_us_delta() of every call, and prints both totals.
 *
 * Return: 0 on success, -ENOMEM if the array cannot be allocated, or 1 as
 * soon as the two implementations disagree on a median.
 */
static int quickselect_test(void)
{
	s64 *arr;
	s64 median_old, median_new;
	ktime_t start, end;
	s64 delta, time_old = 0, time_new = 0;
	const size_t array_length = 64;
	const size_t round = 10000;
	int ret = 0;

	arr = kmalloc(array_length * sizeof(*arr), GFP_KERNEL);
	if (!arr)
		return -ENOMEM;

	for (size_t i = 0; i < round; i++) {
		/* The median helpers reorder their input: reseed before each. */
		init_array(arr, array_length, i + 1);
		start = ktime_get();
		median_old = cros_ec_sensor_ring_median(arr, array_length);
		end = ktime_get();
		delta = ktime_us_delta(end, start);
		time_old += delta;

		init_array(arr, array_length, i + 1);
		start = ktime_get();
		median_new = cros_ec_sensor_ring_median_new(arr, array_length);
		end = ktime_get();
		delta = ktime_us_delta(end, start);
		time_new += delta;

		if (median_old != median_new) {
			/* Leave through the common exit so arr is still freed. */
			ret = 1;
			goto out;
		}
	}

	printk(KERN_ALERT "Total time of original function: %lld\n", time_old);
	printk(KERN_ALERT "Total time of new function: %lld\n", time_new);

out:
	kfree(arr);

	/* 0 on success, 1 if the medians differed */
	return ret;
}

/* Result:
* Total time of original function: 157561
* Total time of new function: 1480
*/

.../platform/chrome/cros_ec_sensorhub_ring.c | 62 ++++++++++++++-----
1 file changed, 45 insertions(+), 17 deletions(-)

diff --git a/drivers/platform/chrome/cros_ec_sensorhub_ring.c b/drivers/platform/chrome/cros_ec_sensorhub_ring.c
index 9e17f7483ca0..1205219515d6 100644
--- a/drivers/platform/chrome/cros_ec_sensorhub_ring.c
+++ b/drivers/platform/chrome/cros_ec_sensorhub_ring.c
@@ -133,33 +133,61 @@ int cros_ec_sensorhub_ring_fifo_enable(struct cros_ec_sensorhub *sensorhub,
return ret;
}

-static int cros_ec_sensor_ring_median_cmp(const void *pv1, const void *pv2)
+static void cros_ec_sensor_ring_median_swap(s64 *a, s64 *b)
{
-	s64 v1 = *(s64 *)pv1;
-	s64 v2 = *(s64 *)pv2;
-
-	if (v1 > v2)
-		return 1;
-	else if (v1 < v2)
-		return -1;
-	else
-		return 0;
+	const s64 held = *b;	/* keep *b; the next line overwrites it */
+	*b = *a;
+	*a = held;
}

/*
* cros_ec_sensor_ring_median: Gets median of an array of numbers
*
- * For now it's implemented using an inefficient > O(n) sort then return
- * the middle element. A more optimal method would be something like
- * quickselect, but given that n = 64 we can probably live with it in the
- * name of clarity.
+ * Uses quickselect with a median-of-three pivot: O(n) on average, O(n^2)
+ * in the worst case. Median-of-medians would guarantee O(n), but its
+ * higher overhead and complexity make quickselect the pragmatic choice
+ * for our use case. The pivot is held in an s64, like the samples, so
+ * values wider than int are compared without truncation.
+ *
+ * Return: array[length / 2] of the sorted input, or -1 if length is 0.
*
- * Warning: the input array gets modified (sorted)!
+ * Warning: the input array gets modified!
*/
static s64 cros_ec_sensor_ring_median(s64 *array, size_t length)
{
-	sort(array, length, sizeof(s64), cros_ec_sensor_ring_median_cmp, NULL);
-	return array[length / 2];
+	int lo = 0;
+	int hi = length - 1;
+
+	while (lo <= hi) {
+		int mid = lo + (hi - lo) / 2;
+		int i = lo - 1;
+		s64 pivot;
+
+		if (array[lo] > array[mid])
+			cros_ec_sensor_ring_median_swap(&array[lo], &array[mid]);
+		if (array[lo] > array[hi])
+			cros_ec_sensor_ring_median_swap(&array[lo], &array[hi]);
+		if (array[mid] < array[hi])
+			cros_ec_sensor_ring_median_swap(&array[mid], &array[hi]);
+		/* array[hi] now holds the median-of-three pivot. */
+		pivot = array[hi];
+
+		for (int j = lo; j < hi; j++)
+			if (array[j] < pivot)
+				cros_ec_sensor_ring_median_swap(&array[++i], &array[j]);
+
+		/* The pivot's index corresponds to i+1. */
+		cros_ec_sensor_ring_median_swap(&array[i + 1], &array[hi]);
+		if (i + 1 == length / 2)
+			return array[i + 1];
+		if (i + 1 > length / 2)
+			hi = i;
+		else
+			lo = i + 2;
+	}
+
+	/* Should never reach here. */
+	return -1;
}

/*
--
2.25.1