社论 22.10.9 优化连续段dp

2023-02-12,,,,


CF840C

给定一个序列 \(a\),长度为 \(n\)。试求有多少 \(1\) 到 \(n\) 的排列 \(p_i\),满足对于任意的 \(2\le i\le n\) 有 \(a_{p_{i-1}}\times a_{p_i}\) 不为完全平方数,答案对 \(10^9+7\) 取模。

\(n \le 300,a_i \le 10^9\).

连续……不是很连续 段dp。

首先考虑什么时候有限制“\(a_{p_{i-1}}\times a_{p_i}\) 为完全平方数”。

我们设 \(a^2 = b\times c,b = p\times x^2, c =q\times y^2\),其中 \(p,q\) 无平方因子。则有 \(p\times q = \left(\frac{a}{xy}\right)^2\)。由于 \(p,q\) 无平方因子,因此若 \(p\times q\) 为完全平方数,则唯一的可能就是 \(p=q\)。

因此我们预先筛出 \(\le \sqrt{\max a_i}\) 的质数,把每个数的平方因子除去,问题就转化成了不能有相邻两个数相同的排列数量。

考虑一个插入dp。

枚举当前插入的数是哪个数,再枚举插入的位置,就可以开始分类讨论了。

    插入数和一侧的数相同
    插入数和两侧的数相同
    插入数和两侧的数都不同,两侧的数彼此不同
    插入数和两侧的数都不同,两侧的数彼此相同

然后开始设dp数组。

设 \(f[i][j][k]\) 为插入了 \(i\) 个数,已经有 \(j\) 对相邻且相同的数,其中 \(k\) 对是由和当前插入数相同的数产生的方案数。

设当前数插入前已经插入了 \(x\) 个和当前数相同的数,开始转移:

    \(f[i][j][k] = f[i][j][k] + f[i-1][j-1][k-1] \times (2(x - 2(k-1)) + 2(k-1))\)
    \(f[i][j][k] = f[i][j][k] + f[i-1][j-1][k-1] \times (k-1)\)
    \(f[i][j][k] = f[i][j][k] + f[i-1][j][k] \times (i-(2x-k)-(j-k))\)
    \(f[i][j][k] = f[i][j][k] + f[i-1][j+1][k] \times (j-k+1)\)

然后转移就完了。记得压一维,但不压似乎也没问题。

时间复杂度 \(O(n^3)\)。

code
#include <bits/stdc++.h>
#include <bits/extc++.h>
using namespace std;
// Inclusive ascending / descending integer loops.
// The `register` storage class was removed in C++17, so the loop variable is
// declared as a plain int.
#define rep(i, a, b) for (int(i) = (a); (i) <= (b); ++(i))
#define pre(i, a, b) for (int(i) = (a); (i) >= (b); --(i))
const int N = 3e2 + 10;
// n      : number of input values
// a[1..n]: input values, later replaced by their square-free parts
// tmp    : scratch value (first the maximum input, later the number of
//          already inserted elements equal to the current one)
// f      : rolling DP table, f[parity][j][k]
int n, a[N], tmp, mod = 1e9 + 7, f[2][N][N];
int prime[1000005], cnt;
bool vis[1000005];

// Linear sieve: stores every prime <= bnd in prime[1..cnt] and marks
// composites in vis[].
void sieve(int bnd) {
    rep(i, 2, bnd) {
        if (!vis[i]) prime[++cnt] = i;
        rep(j, 1, cnt) {
            if (i * prime[j] > bnd) break;
            vis[i * prime[j]] = 1;
            // Stop at the smallest prime factor of i so each composite is
            // marked once.
            if (i % prime[j] == 0) break;
        }
    }
}

// Reads n and a[1..n], strips square factors, then runs the insertion DP and
// prints f[n & 1][0][0], i.e. the number of orderings with no two equal
// square-free parts adjacent, modulo `mod`.
int main() {
    ios::sync_with_stdio(false); cin.tie(0), cout.tie(0);
    cin >> n;
    rep(i, 1, n) cin >> a[i], tmp = max(tmp, a[i]);

    // Primes up to sqrt(max a_i) are enough to remove every square factor.
    sieve(sqrt(tmp) + 10);
    rep(i, 1, cnt) prime[i] = prime[i] * prime[i];
    rep(i, 1, n) rep(j, 1, cnt) while (a[i] % prime[j] == 0) a[i] /= prime[j];

    // Equal square-free parts become contiguous, so they are inserted as a group.
    sort(a + 1, a + 1 + n);

    f[0][0][0] = 1; tmp = 0;
    rep(i, 1, n) {
        int ptr = i & 1, ztr = ptr ^ 1;
        memset(f[ptr], 0, sizeof f[ptr]);

        // A new value starts: pairs of the previous value are no longer
        // "pairs of the current value", so fold every k into k = 0.
        if (a[i] != a[i - 1]) {
            rep(j, 0, i) {
                rep(k, 1, tmp) {
                    f[ztr][j][0] = (f[ztr][j][k] + f[ztr][j][0]) % mod;
                    f[ztr][j][k] = 0;
                }
            }
            tmp = 0;
        }

        rep(j, 0, i) rep(k, 0, min(tmp, j)) {
            // Inserted next to an equal element: one more equal pair of the
            // current value (2*tmp - (k - 1) such gaps).
            if (j > 0 and k > 0)
                f[ptr][j][k] = (f[ptr][j][k] + 1ll * f[ztr][j - 1][k - 1] * ((tmp << 1) - k + 1)) % mod;
            // Inserted inside an equal pair of another value: that pair is broken.
            f[ptr][j][k] = (f[ptr][j][k] + 1ll * f[ztr][j + 1][k] * (j + 1 - k)) % mod;
            // Inserted into any remaining gap: pair counts are unchanged.
            f[ptr][j][k] = (f[ptr][j][k] + 1ll * f[ztr][j][k] * (i - ((tmp << 1) - k) - (j - k))) % mod;
        }
        tmp++;
    }
    cout << f[n & 1][0][0] << endl;
}

[能不能再给力一点啊?]

题面相同。

\(n \le 5000,a_i \le 10^9\).

容斥题。

为方便,设 \(s_i\) 为除去平方因子后第 \(i\) 大的数出现了多少次。

为方便,首先进行无标号计数,最后乘入 \(\prod (s_i!)\) 即可。

于是问题转化成了有颜色无标号小球排列计数,需要相邻小球颜色不同。这就比较的典。

我们考虑一个容斥。首先设有 \(b\) 段长度大于 \(1\) 的相同颜色段,然后我们把这些段缩成一个小球。设第 \(i\) 种颜色被缩成的小球数是 \(t_i\)(\(\sum t_i = b\)),则有

\[ans = \sum_{b}(-1)^{b}\,(n-b)!\sum_{\sum_i t_i = b}\ \prod_{i}\frac{\binom{s_i-1}{t_i}}{(s_i - t_i)!}
\]

\((n-b)!\) 是当前缩完后的序列的总可能性,然后我们得除去那些同色小球之间重复计数的排列情况,这对应分母 \(\prod_i (s_i-t_i)!\)。分子中的 \(\binom{s_i-1}{t_i}\) 考虑对每种颜色的连续段做一下插板法:在 \(s_i-1\) 个相邻空隙中选出 \(t_i\) 个缩起来。

先不考虑 \((n-b)!\),最后得到答案时乘进去就行。套路地设 \(f_{i,j}\) 为前 \(i\) 个数值,\(b=j\) 时的答案。定义 \(n < m\) 时 \(\binom{n}{m} =0\),\(n < 0\) 时 \(n! = 0\),我们有转移

\[f_{i,j} = \sum_{k=0}^j f_{i-1,j-k} \times \frac{\binom{s_i-1}{k}}{(s_i - k)!}
\]

记得得到答案时乘进去该乘的阶乘。

code
#include <bits/stdc++.h>
#include <bits/extc++.h>
using namespace std;
// Inclusive ascending / descending integer loops.
// The `register` storage class was removed in C++17, so the loop variable is
// declared as a plain int.
#define rep(i, a, b) for (int(i) = (a); (i) <= (b); ++(i))
#define pre(i, a, b) for (int(i) = (a); (i) >= (b); --(i))
// This version of the task allows n <= 5000, so every array indexed by an
// element position or a count needs room for 5000 entries (the previous bound
// of 300 + 10 was carried over from the O(n^3) solution and overflowed).
const int N = 5000 + 10;
// a[1..n]: input values; s[1..cnt]: multiplicity of each distinct square-free
// part; jc / inv: factorials and inverse factorials modulo `mod`;
// f[2][j]: rolling DP row indexed by the number of merged pairs.
int n, ptr, ztr, a[N], s[N], tmp, jc[N], inv[N], mod = 1e9 + 7, f[2][N];
int prime[1000005], cnt;
bool vis[1000005];

// Linear sieve: stores every prime <= bnd in prime[1..cnt] and marks
// composites in vis[].
void sieve(int bnd) {
    rep(i, 2, bnd) {
        if (!vis[i]) prime[++cnt] = i;
        rep(j, 1, cnt) {
            if (i * prime[j] > bnd) break;
            vis[i * prime[j]] = 1;
            // Stop at the smallest prime factor of i so each composite is
            // marked once.
            if (i % prime[j] == 0) break;
        }
    }
}

// Binary exponentiation: returns a^b modulo `mod`.
int qp(int a, int b) {
    int ret = 1;
    while (b) {
        if (b & 1) ret = 1ll * ret * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ret;
}
int C(int n, int m) { if (n < m) return 0; return 1ll * jc[n] * inv[m] % mod * inv[n-m] % mod; } int main() {
ios::sync_with_stdio(false); cin.tie(0), cout.tie(0);
cin >> n; rep(i,1,n) cin >> a[i], tmp = max(tmp, a[i]); sieve(sqrt(tmp) + 10);
rep(i,1,cnt) prime[i] = prime[i] * prime[i];
rep(i,1,n) rep(j,1,cnt) while (a[i] % prime[j] == 0) a[i] /= prime[j];
sort(a+1, a+1+n);
tmp = 0; cnt = 0; jc[0] = inv[0] = 1;
rep(i,1,n) jc[i] = 1ll * jc[i-1] * i % mod;
inv[n] = qp(jc[n], mod - 2);
pre(i,n-1,1) inv[i] = 1ll * inv[i+1] * (i+1) % mod; rep(i,1,n) if (a[i] != a[i-1]) ++ s[++cnt]; else ++ s[cnt];
f[0][0] = 1;
rep(i,1,cnt) {
tmp += s[i] - 1;
ptr = i & 1, ztr = ptr ^ 1;
memset(f[ptr], 0, (tmp + 1) << 2);
rep(j,0,tmp) {
rep(k,0,min(s[i], j)) {
f[ptr][j] = (f[ptr][j] + 1ll * f[ztr][j-k] * C(s[i] - 1, k) % mod * inv[s[i] - k]) % mod;
}
}
} int ans = 0; ptr = cnt & 1;
rep(i,0,tmp)
if (i & 1) ans = (ans + 1ll * (mod - 1) * jc[n - i] % mod * f[ptr][i]) % mod;
else ans = (ans + 1ll * jc[n - i] % mod * f[ptr][i]) % mod;
rep(i,1,cnt) ans = 1ll * ans * jc[s[i]] % mod;
cout << ans << endl;
}

[能不能再给力一点啊?]

题面相同。

\(n \le 10^5,a_i \le 10^9\).

生成函数优化dp题。

考虑构造两个生成函数 \(F_k\) 与 \(G_k\)。

\[F_k(x) = \sum_{i=0}f_{k,i} x^i
\]
\[G_k(x) = \sum_{i=0}^{s_k} \frac{\binom{s_k-1}{i}}{(s_k - i)!}x^i
\]

则我们能发现 \(F_k = F_{k-1} \times G_k\)。而 \(F_0 = 1\)。

设共有 \(b\) 种不同的数值,则答案即为

\[\sum_{i=0}^{n-b}(-1)^{i}\times (n-i)! \times \left([x^i]\prod_{k=1}^{b}G_k\right)
\]

通过分治NTT得到最右边那个多项式的系数,然后算出来就行了。注意最后乘进去一个阶乘。

后面那个东西保证了度数加和是 \(n\)。所以我们开一个堆,每次挑最小的两个卷起来。总时间复杂度 \(O(n \log^2 n)\)。

code(mod = 998244353)
#include <bits/stdc++.h>
// Inclusive ascending / descending integer loops.
// The `register` storage class was removed in C++17, so the loop variable is
// declared as a plain int.
#define rep(i, a, b) for (int(i) = (a); (i) <= (b); ++(i))
#define pre(i, a, b) for (int(i) = (a); (i) >= (b); --(i))
using namespace std;
typedef long long ll;
// N: array capacity; mod: NTT-friendly prime; g: the generator used to build
// the root-of-unity tables.
const int N = 3e6 + 10, mod = 998244353, g = 3;
const int siz_ll = sizeof(ll);
// A preprocessing directive must start its own line, so the conditional block
// is kept apart from the declarations above.
#ifdef ONLINE_JUDGE
// Buffered replacement for getchar(): refills `buf` with fread() whenever it
// is exhausted and returns EOF once no more input is available.
char buf[1<<21], *p1 = buf, *p2 = buf;
inline char getc() {
    return (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++);
}
#define getchar getc
#endif
template <typename T> inline void get(T & x){
x = 0; char ch = getchar(); bool f = false; while (ch < '0' or ch > '9') f = f or ch == '-', ch = getchar();
while (ch >= '0' and ch <= '9') x = (x<<1) + (x<<3) + (ch^48), ch = getchar(); f && (x = -x);
} template <typename T, typename ... Args> inline void get(T & x, Args & ... _Args) { get(x); get(_Args...); } int n, tmp, a[N], s[N], jc[N], inv[N]; int btrs[N]; // butterfly_transform
inline int initrs(int k) {
int limit = 1;
while (limit < k) limit <<= 1;
for (register int i = 0 ; i < limit; i ++)
btrs[i] = (btrs[i >> 1] >> 1) | ((i & 1) ? limit >> 1 : 0);
return limit;
} inline ll qp(ll a, ll b) {
ll ret = 1;
while (b) {
if (b & 1) ret = ret * a % mod;
a = a * a % mod;
b >>= 1;
} return ret;
} const int invg = qp(g, mod - 2), inv2 = qp(2, mod - 2); int L, w[2][1<<19];
inline int __INITILIZE__UNIT__ROOT__() {
L = 1<<19;
w[0][0] = w[1][0] = 1;
int wn = qp(g, (mod-1) / L);
for (register int i = 1; i < L; i++) w[0][L - i] = w[1][i] = 1ll * w[1][i - 1] * wn % mod;
return 1;
} int __INITIALIZER__ = __INITILIZE__UNIT__ROOT__(); struct poly {
vector <ll> f;
ll operator [] (const int & pos) const { return f[pos]; }
ll & operator [] (const int & pos) { return f[pos]; }
int deg() {return f.size(); }
int deg() const {return f.size(); }
void Set(int n) { f.resize(n); }
void Adjust() { while (f.size() > 1 and f.back() == 0) f.pop_back(); }
void scan(int n = -1) { if (n < 0) get(n); Set(n); for (register int i = 0; i < n; i++) get(f[i]); }
void print() { for (ll x : f) printf("%lld ", x); }
inline void NTT (const int lim, const int type) {
Set(lim);
for (register int i = 0; i < lim; i++) if (i < btrs[i]) swap(f[i], f[btrs[i]]);
for (register int mid = 1; mid < lim; mid <<= 1) {
for (register int i = L / (mid<<1), j = 0; j < lim; j += (mid << 1)) {
for (register int k = 0; k < mid; k++) {
int x = f[j + k], y = w[type][i * k] * f[j + k + mid] % mod;
f[j + k] = (x + y) % mod;
f[j + k + mid] = (x - y + mod) % mod;
}
}
} if (type == 1) return;
ll inv = qp(lim, mod - 2);
for (register int i = 0; i < lim; i++) f[i] = f[i] * inv % mod;
}
friend poly operator * (const poly & x, const poly & y) {
poly ret, A = x, B = y;
int limit = initrs(A.deg() + B.deg() - 1);
A.NTT(limit, 1), B.NTT(limit, 1); ret.Set(limit);
for (register int i = 0; i < limit; i++) ret[i] = 1ll * A[i] * B[i] % mod;
ret.NTT(limit, 0); ret.Adjust();
return ret;
} void operator *= (const poly & x) {
poly A = x;
int limit = initrs(deg() + A.deg() - 1);
A.NTT(limit, 1); NTT(limit, 1);
for (register int i = 0; i < limit; i++) f[i] = 1ll * A[i] * f[i] % mod;
NTT(limit, 0); Adjust();
}
} F[100005]; struct pip {
int id, deg;
pip(int _i, int _d) : id(_i), deg(_d) {}
bool operator < (const pip & b) const {
return deg > b.deg;
}
} ; priority_queue <pip> que; int prime[1000005], cnt;
bool vis[1000005];
void sieve(int bnd) {
rep(i,2,bnd) {
if (!vis[i]) prime[++cnt] = i;
rep(j,1,cnt) {
if (i * prime[j] > bnd) break;
vis[i * prime[j]] = 1;
if (i % prime[j] == 0) break;
}
}
} int C(int n, int m) { if (n < m) return 0; return 1ll * jc[n] * inv[m] % mod * inv[n-m] % mod; } signed main() {
ios::sync_with_stdio(false); cin.tie(0), cout.tie(0);
cin >> n; rep(i,1,n) cin >> a[i], tmp = max(tmp, a[i]); sieve(sqrt(tmp) + 10);
rep(i,1,cnt) prime[i] = prime[i] * prime[i];
rep(i,1,n) rep(j,1,cnt) while (a[i] % prime[j] == 0) a[i] /= prime[j];
sort(a+1, a+1+n);
tmp = 0; cnt = 0; jc[0] = inv[0] = 1;
rep(i,1,n) jc[i] = 1ll * jc[i-1] * i % mod;
inv[n] = qp(jc[n], mod - 2);
pre(i,n-1,1) inv[i] = 1ll * inv[i+1] * (i+1) % mod; rep(i,1,n) if (a[i] != a[i-1]) ++ s[++cnt]; else ++ s[cnt];
rep(k,1,cnt) {
F[k].Set(s[k]+1);
que.emplace(k, s[k]+1);
rep(i,0,s[k]) F[k][i] = 1ll * C(s[k] - 1, i) * inv[s[k] - i] % mod;
} tmp = n - cnt; while (que.size() > 1) {
int x1 = que.top().id; que.pop();
int x2 = que.top().id; que.pop();
F[x1] = F[x1] * F[x2];
que.emplace(x1, F[x1].deg());
F[x2].f.clear(); F[x2].f.shrink_to_fit();
} int ans = 0, ptr = que.top().id;
rep(i,0,tmp)
if (i & 1) ans = (ans + 1ll * (mod - 1) * jc[n - i] % mod * F[ptr][i]) % mod;
else ans = (ans + 1ll * jc[n - i] % mod * F[ptr][i]) % mod;
rep(i,1,cnt) ans = 1ll * ans * jc[s[i]] % mod;
cout << ans << endl;
}
code(mod = any)
#include <bits/stdc++.h>
// Inclusive ascending / descending integer loops.
// The `register` storage class was removed in C++17, so the loop variable is
// declared as a plain int.
#define rep(i, a, b) for (int(i) = (a); (i) <= (b); ++(i))
#define pre(i, a, b) for (int(i) = (a); (i) >= (b); --(i))
using namespace std;
const int N = 3e5 + 10, mod = 1e9 + 7;
// a[1..n]: input values; s[1..cnt]: group sizes; jc / inv: factorials and
// inverse factorials modulo `mod`.
int n, tmp, a[N], s[N], jc[N], inv[N];
// A preprocessing directive must start its own line, so these macros are kept
// apart from the declarations above.
// fp: i runs from a up to b inclusive; fd: i runs from a down to b inclusive.
#define fp(i, a, b) for (int i = (a), i##_ = (b) + 1; i < i##_; ++i)
#define fd(i, a, b) for (int i = (a), i##_ = (b) - 1; i > i##_; --i)
using ll = int64_t;
using db = double;

// Minimal complex number (x = real part, y = imaginary part).
struct cp {
    db x, y;
    cp(db real = 0, db imag = 0) : x(real), y(imag){};
    cp operator+(cp b) const { return {x + b.x, y + b.y}; }
    cp operator-(cp b) const { return {x - b.x, y - b.y}; }
    cp operator*(cp b) const { return {x * b.x - y * b.y, x * b.y + y * b.x}; }
};
using vcp = vector<cp>;
using Poly = vector<int>;

namespace FFT {
const db pi = acos(-1);

// Builds the twiddle table of length L level by level: the entries in
// [i, 2i) are derived from the entries in [i/2, i), the odd ones multiplied
// by exp(i*pi / i).
vcp Omega(int L) {
    vcp w(L); w[1] = 1;
    for (int i = 2; i < L; i <<= 1) {
        auto w0 = w.begin() + i / 2, w1 = w.begin() + i;
        cp wn(cos(pi / i), sin(pi / i));
        for (int j = 0; j < i; j += 2)
            w1[j] = w0[j >> 1], w1[j + 1] = w1[j] * wn;
    }
    return w;
}
auto W = Omega(1 << 21);

// Forward transform of a[0..n); n must be a power of two not larger than the
// table. The output is left in the permuted order that IDIT expects.
void DIF(cp *a, int n) {
    for (int k = n >> 1; k; k >>= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                const cp x = a[i + j], y = a[i + j + k];
                a[i + j + k] = (x - y) * W[k + j];
                a[i + j] = x + y;
            }
}

// Inverse of DIF: butterflies in the opposite order, scaling by 1/n and a
// final reversal of a[1..n).
void IDIT(cp *a, int n) {
    for (int k = 1; k < n; k <<= 1)
        for (int i = 0; i < n; i += k << 1)
            for (int j = 0; j < k; ++j) {
                const cp x = a[i + j], y = a[i + j + k] * W[k + j];
                a[i + j + k] = x - y;
                a[i + j] = x + y;
            }
    const db Inv = 1. / n;
    fp(i, 0, n - 1) a[i].x *= Inv, a[i].y *= Inv;
    reverse(a + 1, a + n);
}
}

namespace MTT{
// Product of a and b with coefficients reduced modulo P, for an arbitrary
// modulus. Each coefficient is split into a low 15-bit part and a high part so
// the floating-point products stay small enough to be rounded back.
// Returns an empty polynomial when either input is empty.
Poly conv(const Poly &a, const Poly &b, const int&P) {
    if (a.empty() || b.empty()) return Poly();
    const int n = a.size(), m = b.size(), o = n + m - 1;
    // Transform length: the smallest power of two >= o, but at least 2 so the
    // pairing loop below also covers a length-1 result (the former
    // __lg(o - 1) was undefined for o == 1).
    int l = 2;
    while (l < o) l <<= 1;
    vcp A(l), B(l), c0(l), c1(l);
    for (int i = 0; i < n; i++) A[i] = cp(a[i] & 0x7fff, a[i] >> 15);
    for (int i = 0; i < m; i++) B[i] = cp(b[i] & 0x7fff, b[i] >> 15);
    FFT::DIF(A.data(), l), FFT::DIF(B.data(), l);
    // Within each level [k, 2k) index i is paired with j = i ^ (k - 1); the
    // pair is used to separate the transforms of the low and high parts of a.
    for (int k = 1, i = 0, j; k < l; k <<= 1)
        for (; i < k * 2; i++) {
            j = i ^ (k - 1);
            c0[i] = cp(A[i].x + A[j].x, A[i].y - A[j].y) * B[i] * 0.5;
            c1[i] = cp(A[i].y + A[j].y, -A[i].x + A[j].x) * B[i] * 0.5;
        }
    FFT::IDIT(c0.data(), l), FFT::IDIT(c1.data(), l);
    Poly res(o);
    for (int i = 0; i < o; i++) {
        // Round to the nearest integer, then recombine low*low, the two cross
        // terms (shifted by 15 bits) and high*high (shifted by 30 bits).
        ll c00 = c0[i].x + 0.5, c01 = c0[i].y + 0.5, c10 = c1[i].x + 0.5, c11 = c1[i].y + 0.5;
        res[i] = (c00 + ((c01 + c10) % P << 15) + (c11 % P << 30)) % P;
    }
    return res;
}
}

namespace Polynomial {
void DFT(vcp &a) { FFT::DIF(a.data(), a.size()); }
void IDFT(vcp &a) { FFT::IDIT(a.data(), a.size()); }

// Smallest power of two >= n (1 for n <= 1; the former __lg(n - 1) was
// undefined for n == 1).
int norm(int n) {
    int len = 1;
    while (len < n) len <<= 1;
    return len;
}

// Pointwise product a[i] *= b[i]; returns a.
vcp &dot(vcp &a, vcp &b) { fp(i, 0, a.size() - 1) a[i] = a[i] * b[i]; return a; }

// Plain (non-modular) product computed with a single transform: a goes into
// the real part, b into the imaginary part, and the imaginary part of the
// square is 2ab. The result overwrites a and is also returned.
// An empty operand yields an empty result.
Poly operator*(Poly &a, Poly &b) {
    if (a.empty() || b.empty()) { a.clear(); return a; }
    int n = a.size() + b.size() - 1;
    vcp c(norm(n));
    fp(i, 0, a.size() - 1) c[i].x = a[i];
    fp(i, 0, b.size() - 1) c[i].y = b[i];
    DFT(c), dot(c, c), IDFT(c), a.resize(n);
    fp(i, 0, n - 1) a[i] = int(c[i].y * .5 + .5);
    return a;
}
}
using namespace Polynomial; Poly F[N]; int qp(int a, int b) {
int ret = 1;
while (b) {
if (b & 1) ret = 1ll * ret * a % mod;
a = 1ll * a * a % mod;
b >>= 1;
} return ret;
} struct pip {
int id, deg;
pip(int _i, int _d) : id(_i), deg(_d) {}
bool operator < (const pip & b) const {
return deg > b.deg;
}
} ; priority_queue <pip> que; int prime[1000005], cnt;
bool vis[1000005];
void sieve(int bnd) {
rep(i,2,bnd) {
if (!vis[i]) prime[++cnt] = i;
rep(j,1,cnt) {
if (i * prime[j] > bnd) break;
vis[i * prime[j]] = 1;
if (i % prime[j] == 0) break;
}
}
} int C(int n, int m) { if (n < m) return 0; return 1ll * jc[n] * inv[m] % mod * inv[n-m] % mod; } signed main() {
ios::sync_with_stdio(false); cin.tie(0), cout.tie(0);
cin >> n; rep(i,1,n) cin >> a[i], tmp = max(tmp, a[i]); sieve(sqrt(tmp) + 10);
rep(i,1,cnt) prime[i] = prime[i] * prime[i];
rep(i,1,n) rep(j,1,cnt) while (a[i] % prime[j] == 0) a[i] /= prime[j];
sort(a+1, a+1+n);
tmp = 0; cnt = 0; jc[0] = inv[0] = 1;
rep(i,1,n) jc[i] = 1ll * jc[i-1] * i % mod;
inv[n] = qp(jc[n], mod - 2);
pre(i,n-1,1) inv[i] = 1ll * inv[i+1] * (i+1) % mod; rep(i,1,n) if (a[i] != a[i-1]) ++ s[++cnt]; else ++ s[cnt];
rep(k,1,cnt) {
F[k].resize(s[k]+1);
que.emplace(k, s[k]+1);
rep(i,0,s[k]) F[k][i] = 1ll * C(s[k] - 1, i) * inv[s[k] - i] % mod;
} tmp = n - cnt; while (que.size() > 1) {
int x1 = que.top().id; que.pop();
int x2 = que.top().id; que.pop();
F[x1] = MTT :: conv(F[x1], F[x2], mod);
que.emplace(x1, F[x1].size());
F[x2].clear(); F[x2].shrink_to_fit();
} int ans = 0, ptr = que.top().id;
rep(i,0,tmp)
if (i & 1) ans = (ans + 1ll * (mod - 1) * jc[n - i] % mod * F[ptr][i]) % mod;
else ans = (ans + 1ll * jc[n - i] % mod * F[ptr][i]) % mod;
rep(i,1,cnt) ans = 1ll * ans * jc[s[i]] % mod;
cout << ans << endl;
}

进行了一下测试,\(n = 3\times 10^5, a[i]\in[1,n]\) 的点开-O2后是0.6s。


关于“能不能再给力一点啊?”这件事,先咕着。

社论 22.10.9 优化连续段dp的相关教程结束。

《社论 22.10.9 优化连续段dp.doc》

下载本文的Word格式文档,以方便收藏与打印。