算法学习笔记(3): 倍增与ST算法

2023-02-23,,

倍增

目录
倍增
查找 洛谷P2249
重点
变式练习
快速幂
ST表
扩展 - 运算
扩展 - 区间
变式答案

倍增,字面意思即”成倍增长“

他与二分十分类似,都是基于”2“的划分思想

那么具体是怎么样,我们以一个例子来看

ST表才是文章的重点 QwQ


查找 洛谷P2249

依据题面,我们知道这是一个单调序列,当然可以通过二分的方式来寻找答案,但是既然我们这里讲倍增,那么就用倍增来写吧!

首先,我们先贴上核心代码

void find(int k) {
int i = 0, p = 1;
while (p) {
if (i + p < n && a[i + p] < k) i += p, p <<= 1;
else p >>= 1;
} if (a[i + 1] == k)
printf("%d ", i + 1);
else
printf("-1 ");
}

其中i表示所寻找的下标,p表示步长。

算法步骤如下:

    保证i + p没有超过上界并比较a[i + p]k的大小关系,如果小于k,证明最终答案必定在i之后,所以将i设为i + p,并将步长p乘以2;否则,将步长p除以2

    重复上一步,直到步长p == 0,此时,a[i]为严格小于k的最后一个数。

    如果a[i + 1]不为k,则k不存在于数组中,输出-1;否则,输出i + 1

其实不难发现,其实这种代码比而二分的代码简洁了很多,所以我很喜欢用倍增

了解了上述步骤,我们可以发现,倍增的思想体现在步长之上,那为什么步长关于2的变换时正确的呢?

其实我们很容易知道,每一个数都可以以二进制数表示,而这里的步长从某种意义上来说相当于对于数的每一个二进制位的修改。即是用了“二进制划分”的思想。


重点

像上面代码写的倍增最终i的位置是最后一个满足if后的条件的位置


变式练习

如果我们把问题改为寻找最后一次出现的位置呢?这时算法该如何书写?

参考代码见文末


快速幂

其实,从上面的例子中我们已经对于倍增的思想有了一些体会。

实际上,“倍增”与“二进制划分”两个思想相互结合,才碰撞出了不一样的烟火。如这里的快速幂。

快速幂可以参考这篇文章:算法学习笔记(4):快速幂 - 知乎 (zhihu.com)

但是,在这篇文章的讲述中,快速幂的递归形式实际上时使用了二分的思想。而只有递推的形式才属于倍增的思想。

其实这里我们可以看出倍增与二分的联系:倍增类似于二分的逆过程,当然,这并不准确。

上面链接所给文章中快速幂讲述的十分清楚,甚至有额外的拓展,所以就不再详细展开。

这里给出一个快速幂的参考代码

// (a**x) % p
int quickPow(int a, int x, int p) {
int r = 1;
while (x) {
// no need to use quickMul when p*p can be smaller than int64.max !!!
if (x & 1) r = (r * a) % p;
a = (a * a) % p, x >>= 1;
}
return r;
}

ST表

在RMQ(区间最值)问题中,著名的ST算法就是倍增的产物。ST算法可以在\(O(N\,log\,N)\)的时间复杂度能预处理后,以\(O(1)\)的复杂度在线回答区间[l, r]内的最值。

当然,ST表不支持动态修改,如果需要动态修改,线段树是一种良好的解决方案,也是\(O(N\,log\,N)\)的时间复杂度,但是查询需要\(O(logN)\)的时间复杂度

那么ST表中倍增的思想是如何体现的呢?

一个序列的子区间明显有\(N^2\)个,根据倍增的思想,我们在这么多个子区间中选择一些长度为\(2\)的整数次幂的区间作为代表值。

设\(st[i][j]\)表示子区间\([i, i+2^j)\)里最大的数

也可以表示为\([i, i + 2^j -1 ]\),无论如何,其中有\(2^j\)个元素

下文中的\(a\)表示原序列

递推边界明显是\(st[i][0] = a[i]\)。

于是,根据成倍增长的长度,有了递推公式

\[st[i][j] = max(st[i][j-1],\;st[i+2^{j-1}][j-1])
\]

当询问任意区间\([l, r]\)的最值时,我们先计算出一个最大的\(k\)满足:\(2^k \le r - l + 1\),即需要不大于区间长度。那么,由于二进制划分我们可以知道,这个最大的k一定满足\(2^{k+1}\ge r-l+1\),即我们只需要将两个长度为\(2^k\)的区间合并即可。

又根据max(a, a) = a可以知道,重复计算区间是没有任何问题的。

所以,在寻找最值的时候就有了以下公式:

\[max(a[l, r]) = max(st[l][k], st[r-2^k + 1][k])
\]

那么这里给出一种参考代码

// 啊,写这种预处理以2位底的对数的整数值的方式
// 我主要是为了将代码模块化,做到低耦合度
// 完全是可以分开来写的
class Log2Factory {
private:
int lg2[N];
public:
void init(int n) {
for (int i = 2; i <= n; ++i) lg2[i] = lg2[i >> 1] + 1;
} // 重载()运算符
int operator() (const int &i) {
return lg2[i];
}
}; template<typename T>
class STable {
private: typedef T(*OP_FUNC)(T, T);
Log2Factory Log2;
T f[N][17]; // maybe most of the times k=17 is ok, make sure 2^k greater than N;
OP_FUNC op;
public:
void setOp(OP_FUNC fc) {
op = fc;
} void init(T *a, int n) {
for (int i = 1; i <= n; ++i)
f[i][0] = *(++a); int t = Log2(n);
// f[i][k] is the interval of [i, i + 2^k - 1]
// so f[i][k] can equal to the op sum of [i, i^k - 1]
// let r = i^k - 1
// => f[r - (1^k) + 1][k] can equal to the op sum of [i][k]
for (int k = 1; k <= t; ++k) {
for (int i = 1; i + (1<<k) - 1 <= n; ++i)
f[i][k] = op(f[i][k-1], f[i + (1<<(k-1))][k-1]);
}
} const T query(int l, int r) {
int k = Log2(r - l + 1);
return op(f[l][k], f[r - (1<<k) + 1][k]);
}
};

这……写法很神奇,注意修改!

扩展 - 运算

ST算法不仅仅是可以求区间的最值的,只要时满足静态的,满足区间加法的问题大多数情况都可以通过ST表实现。

那么区间加法是什么意思呢?

定义我们需要对数列的筛选函数为op,则需要op满足以下性质

op(a, a) = a,即重复参与运算不改变最终影响

op(a, b) = op(b, a),即满足交换律

op(a, op(b, c)) = op(op(a, b), c),即满足结合律

举个例子,如果我们求区间是否有负数,可以将op设为如下逻辑:

bool op(bool a, bool b) {
return a | b;
}

相应的,初始化的方式也需要更改

if (a[i] < 0) st[i][0] = true;
else st[i][0] = false;

再举一个例子,如果我们需要求区间是否全为偶数时,则初始化为

if (a[i] % 2 == 0) st[i][0] = true;
else st[i][0] = false;

操作op定义为

bool op(bool a, bool b) {
return a & b;
}

由此可见,其实ST算法可以做到的不仅仅是区间最值那么普通的东西啊。

但是,由于加法不满足性质一,所以,ST表通过这种方法并不能求得区间的所有满足某种性质的元素的个数。但是,通过另外一种query方式,我们可以做到这样。

扩展 - 区间

那么这个部分我们将讨论如何利用ST表做到上文例子中求区间偶数的个数。

同样,由于我们可以通过二进制划分,所以可以将某一个区间长度转化为多个长度为2的整数幂次方的子区间,并且可以保证这些区间不相互重叠

其实这是借鉴了一点线段树的思路

那么可以写出以下代码

int query(int l, int r) {
if (l == r) return st[l][0];
int k = log2(r - l + 1);
return op(st[l][k], query(l + (1<<k), r))
}

这样就满足了区间不重叠

或许会有一个问题,为什么初始化的时候不需要修改?

其实不难发现,初始化的合并是不会有重复贡献的情况的,即是每一次合并的区间是不会重叠的


变式答案

其实非常类似的!

void find(int k) {
int i = 0, p = 1;
while (p) {
if (i + p <= n && a[i + p] <= k) i += p, p <<= 1;
else p >>= 1;
} if (a[i] == k)
printf("%d ", i);
else
printf("-1 ");
}

算法学习笔记(3): 倍增与ST算法的相关教程结束。

《算法学习笔记(3): 倍增与ST算法.doc》

下载本文的Word格式文档,以方便收藏与打印。