Sorting Algorithm

排序算法作为最基本的算法之一,其重要性不言而喻,时至今日仍在被优化改进。

2023 年 6 月 8 日,Google DeepMind 新作 AlphaDev 发现了更快的排序算法,同时相关文章 Faster sorting algorithms discovered using deep reinforcement learning 在 Nature 发表。

有趣的是不久后 Dimitris Papailiopoulos 发推 GPT-4 “discovered” the same sorting algorithm as AlphaDev by removing “mov S P” 表示通过合适的提示词,GPT-4 也能得到同样结果,还引得马斯克围观。

C++ STL 中 std::sort 采用的是 David Musser 提出的混合排序算法,主要使用快速排序,同时记录递归深度,看情况转为堆排序,此外对短数据采用其他排序算法。具体实现可参考 Go 1.18 的 sort 包,相较 C++ 可读性更强。

LLVM 作为被广泛使用的 C++ 编译后端,其 libc++ 对于短数据采用手写 sort3 sort4 sort5 的形式,AlphaDev 对此进行了优化,相关 Commit 已被合并,这也是排序算法库十多年来第一次更新。

AlphaDev 从汇编出发,以发现更快的排序和散列算法为目标,通过强化学习解组合优化问题,根据计算耗时和结果来评估奖励,不断迭代,最后在底层发现了可优化的部分。

图为 sort3 汇编代码的前后对比,难以相信被优化了几十年,每天被调用数万亿次的排序算法,还存在着如此简单的优化方法。

大模型的推理主要是模式识别,知识整合等,而非真正的逻辑推理。但我经过验证,在指定范围的情况下,最新的 GPT-4 的确可以对上述代码进行优化,得到与论文相同的结果。此外,大模型能否作为通解代替特定程序,去优化更多微小细节,还有待研究。


十大排序算法:插入排序、希尔排序、选择排序、冒泡排序、快速排序、堆排序、归并排序、计数排序、桶排序、基数排序。

name time space stable
insertion n2n^2 11 yes
shell n2n^2 11 no
selection n2n^2 11 no
bubble n2n^2 11 yes
quick nlognn \log n logn\log n no
heap nlognn \log n logn\log n no
merge nlognn \log n nn yes
counting n+kn + k n+kn + k yes
bucket n+kn + k nknk yes
radix d(n+k)d(n + k) n+kn + k yes

相同的算法也有不同的实现,此处算法的处理对象为数组,其中希尔排序的平均时间复杂度取决于所选序列,快速排序和堆排序的空间复杂度取决于重复方式,kk 分别为计数范围、桶数、基数,dd 为阶数。

插入排序:

void insertion_sort(int arr[], int len) {
    for (int i = 1; i < len; i++) {
        int j, key = arr[i];
        for (j = i; j > 0; j--)
            if (arr[j - 1] > key)
                arr[j] = arr[j - 1];
            else
                break;
        arr[j] = key;
    }
}

希尔排序:

void shell_sort(int arr[], int len) {
    int n = 3, gap[] = {1, 3, 5};
    while (n--)
        for (int h = gap[n], i = h; i < len; i++)
            for (int j = i; j >= h; j -= h)
                if (arr[j - h] > arr[j])
                    swap(arr, j - h, j);
                else
                    break;
}

选择排序:

void selection_sort(int arr[], int len) {
    for (int i = len - 1; i > 0; i--) {
        int max = i;
        for (int j = 0; j < i; j++)
            if (arr[j] > arr[max])
                max = j;
        swap(arr, max, i);
    }
}

冒泡排序:

void bubble_sort(int arr[], int len) {
    for (int i = 0; i < len - 1; i++)
        for (int j = 0; j < len - 1 - i; j++)
            if (arr[j] > arr[j + 1])
                swap(arr, j, j + 1);
}

快速排序:

void quick(int arr[], int l, int r) {
    if (l >= r)
        return;
    int mid = arr[(l + r) / 2], i = l, j = r;
    while (i <= j) {
        while (arr[i] < mid)
            i++;
        while (arr[j] > mid)
            j--;
        if (i <= j)
            swap(arr, i, j), i++, j--;
    }
    quick(arr, l, j);
    quick(arr, i, r);
}
void quick_sort(int arr[], int len) {
    quick(arr, 0, len - 1);
}

堆排序:

void heap(int arr[], int len, int i) {
    int max = i, l = 2 * i + 1, r = 2 * i + 2;
    if (l < len && arr[max] < arr[l])
        max = l;
    if (r < len && arr[max] < arr[r])
        max = r;
    if (max != i)
        swap(arr, i, max), heap(arr, len, max);
}
void heap_sort(int arr[], int len) {
    for (int i = len / 2 - 1; i >= 0; i--)
        heap(arr, len, i);
    for (int i = len - 1; i > 0; i--)
        swap(arr, 0, i), heap(arr, i, 0);
}

归并排序:

void merge(int arr[], int tmp[], int l, int r) {
    if (l >= r)
        return;
    int m = (l + r) / 2;
    merge(arr, tmp, l, m);
    merge(arr, tmp, m + 1, r);
    int i = l, j = m + 1, k = l;
    while (i <= m && j <= r)
        tmp[k++] = arr[i] <= arr[j] ? arr[i++] : arr[j++];
    while (i <= m)
        tmp[k++] = arr[i++];
    while (j <= r)
        tmp[k++] = arr[j++];
    for (i = l; i <= r; i++)
        arr[i] = tmp[i];
}
void merge_sort(int arr[], int len) {
    int tmp[len];
    merge(arr, tmp, 0, len - 1);
}

计数排序:

void counting_sort(int arr[], int len) {
    int MIN = 0, MAX = 99;
    int num = 100;
    int cnt[num], tmp[len];
    for (int i = 0; i < num; i++)
        cnt[i] = 0;
    for (int i = 0; i < len; i++)
        cnt[arr[i]]++;
    for (int i = 1; i < num; i++)
        cnt[i] += cnt[i - 1];
    for (int i = len - 1; i >= 0; i--)
        tmp[--cnt[arr[i]]] = arr[i];
    for (int i = 0; i < len; i++)
        arr[i] = tmp[i];
}

桶排序:

void bucket_sort(int arr[], int len) {
    int MIN = 0, MAX = 99;
    int num = 10, gap = 10;
    int cnt[num], tmp[num][len];
    for (int i = 0; i < num; i++)
        cnt[i] = 0;
    for (int i = 0; i < len; i++)
        tmp[arr[i] / gap][cnt[arr[i] / gap]++] = arr[i];
    for (int i = 0; i < num; i++)
        insertion_sort(tmp[i], cnt[i]);
    for (int n = 0, i = 0; i < num; i++)
        for (int j = 0; j < cnt[i]; j++)
            arr[n++] = tmp[i][j];
}

基数排序:

void radix_sort(int arr[], int len) {
    int MIN = 0, MAX = 99;
    int num = 10;
    int cnt[num], tmp[len];
    for (int exp = 1; exp <= MAX; exp *= num) {
        for (int i = 0; i < num; i++)
            cnt[i] = 0;
        for (int i = 0; i < len; i++)
            cnt[(arr[i] / exp) % num]++;
        for (int i = 1; i < num; i++)
            cnt[i] += cnt[i - 1];
        for (int i = len - 1; i >= 0; i--)
            tmp[--cnt[(arr[i] / exp) % num]] = arr[i];
        for (int i = 0; i < len; i++)
            arr[i] = tmp[i];
    }
}