「CF1349F2」Slime and Sequences (Hard Version)
定义一个排列 p 是好的当且仅当对于每个 k<max{p},存在 1≤i<j≤n 使得 ai=k−1 且 aj=k。
定义 fa(k) 为序列 a 中数值 k 的出现次数,假设所有合法序列集合为 S,对于每个 k∈[1;n],求
(a∈S∑fa(k))mod998244353
n≤105。
题解
考虑由排列 p 生成序列,在 pi 和 pi+1 之间填入大于或小于号,则 api 的值为 p1...i 中的小于号个数 +1,不难验证合法序列集合和排列集合构成双射。
注意到这是个欧拉数的形式,考虑容斥维护:
ansi=m=0∑nm!n!j≥i∑(m−j)!(−1)j−i(ij){mm−j}=j≥i∑n!(−1)j−i(ij)m=j+1∑n{mm−j}m!(m−j)!
第二类斯特林数的生成函数
{nk}=n![zn]k!(ez−1)k
设 h,先带上 m=i 的情况计算,最后再令 h0 减 1。
hi=m=i∑n{mm−i}m!(m−i)!=m=i∑n[zm](ez−1)m−i=j=0∑n−i[zj+i](ez−1)j=[zi]j=0∑n−i(zez−1)j
令 F(z)=(ez−1)/z,有
hi=[zi]j=0∑n−iFj(z)=[zi]F(z)−1Fn−i+1(z)−1=[zi](F(z)−1−1+F(z)−1Fn−i+1(z))
前半部分容易直接多项式求逆处理,现在考虑后半部分
[zi]F(z)−1Fn−i+1(z)=[zn+1]F(z)−1(zF(z))n−i+1
(开始在这里卡住了,不求甚解的从题解那里拉了个式子,第二天冷静了一下,为了行文连贯,把补充理解附在后面)
设 ω(z)=zF(z),φ(z) 满足 φ(ω(z))ω(z)=z,则 φ(ω(z))zF(z)=z,即 F=φ(ω)
拓展拉格朗日反演
f,g,h 是 F[[x]] 上的多项式,已知 f(g(x))=g(f(x))=x,则
[xn]h(f(x))=n1[xn−1]h′(x)(g(x)x)n
考虑
f(x)g(x)h(x)=ω(x)=φ(x)x=(1−φ(x))(1−ux)1
则
=[un−i+1zn+1]1−φ(ω(z))11−uw(z)1[un−i+1]n+11[zn]((1−φ(z)11−uz1)′⋅φ(z)n+1)
(1−x)k1=i=0∑∞(k−1i+k−1)xi
直接考虑系数组合意义就可以证明。
其中
==(1−φ(z)11−uz1)′(1−uz)(1−φ(z))2φ′(z)+(1−φ(z))(1−uz)2u(1−φ(z))2φ′(z)i=0∑∞uizi+1−φ(z)1i=0∑∞(i+1)ui+1zi
代入到原式中有
=n+1[zn]φn+1(z)((1−φ(z))2zn−i+1φ′(z)+1−φ(z)(n−i+1)zn−i)[zi−1](n+1)(1−φ(z))2φn+1(z)φ′(z)+[zi](n+1)(1−φ(z))φn+1(z)(n−i+1)
唯一的问题就是 φ(x) 怎么求了,注意到 ω(x)=ex−1,构造得 ln(1+z)z。
坑点
- φ(z)=ln(1+z)z,注意到 ln(1+z) 的常数项是 0,不能直接求逆;
- 1−φ(z) 常数项也是 0,同样不能直接求逆,只能求出 (1−φ(z))/z,上面的式子大概变成
[zi+1](n+1)((1−φ(z))/z)2φn+1(z)φ′(z)+[zi+1](n+1)((1−φ(z)/z))φn+1(z)(n−i+1)
- 上面两种情况的处理可能带来更多的多项式长度要求。
- h0 减 1。
补充
现在我们需要对 i∈[0;n],求出
[zn+1]F(z)−1(zF(z))n−i+1
考虑用二元生成函数表示
=[zn+1un−i+1]F(z)−11k=0∑∞(zuF(z))k[zn+1un−i+1](F(z)−1)(1−zuF(z))1
如果直接对着这个式子做,是没有办法拉格朗日反演的,因为本质上我们有两个关于 z 的多项式:F(z) 和 zF(z)。我们设法构造一个关于 F 的多项式使其满足 φ(F(z))=zF(z)。
形式化的,我们想要构造多项式 φ(z) 使得
φ(F(z))=zF(z)
同时,由该式我们也可以直接得到 F 的复合逆形式:
F(z)φ(F(z))=z
说回到 φ(x) 的构造,我们只能利用 F(z) 本身的性质设法构造出 z:我们有
F(z)=zez−1
相当于有
ln(1+zF(z))=z
不幸的是,这个式子本身还和 z 有关,但启发我们改变方向,设 ω(z)=zF(z),构造多项式 φ(z) 使得
{ω(z)=zF(z)φ(ω(z))=F(z)
有 φ(z)=ln(1+z)z,φ(ω(z))ω(z)=z。即把原式转化为了:
(1−uω(z))(ln(1+ω(z))ω(z)−1)1
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 9, mod = 998244353;
struct z {
int x;
z(int x = 0) : x(x) {}
friend inline z operator*(z a, z b) { return (long long)a.x * b.x % mod; }
friend inline z operator-(z a, z b) { return (a.x -= b.x) < 0 ? a.x + mod : a.x; }
friend inline z operator+(z a, z b) { return (a.x += b.x) >= mod ? a.x - mod : a.x; }
} h[N], ans[N], inv[N], fac[N], ifac[N];
inline z fpow(z a, int b) {
z s = 1;
for (; b; b >>= 1, a = a * a)
if (b & 1) s = s * a;
return s;
}
int len = 1, rev[N << 2];
z w[N << 2];
struct vec : vector<z> {
using vector<z>::vector;
inline vec divx() {
vec res = *this;
return res.erase(res.begin()), res;
}
inline vec setl(size_t len) {
vec res = *this;
return res.resize(len), res;
}
inline vec fun1() {
vec res(this->begin() + 1, this->end());
for (int i = 0; i < res.size(); i++) res[i].x = mod - res[i].x;
return res;
}
};
int init(int n) {
int lim = 1, k = 0;
while (lim < n) lim <<= 1, ++k;
for (int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
for (; len < lim; len <<= 1) {
z wn = fpow(3, (mod - 1) / (len << 1));
w[len] = 1;
for (int i = 1; i < len; i++) w[i + len] = w[i + len - 1] * wn;
}
return lim;
}
void dft(vec &a, int lim) {
a.resize(lim);
for (int i = 0; i < lim; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int i = 0; i < lim; i += 2) {
z x = a[i], y = a[i + 1] * w[1];
a[i] = x + y, a[i + 1] = x - y;
}
for (int len = 2; len < lim; len <<= 1)
for (int i = 0; i < lim; i += (len << 1))
for (int j = 0; j < len; j++) {
z x = a[i + j], y = a[i + j + len] * w[j + len];
a[i + j] = x + y, a[i + j + len] = x - y;
}
}
void idft(vec &a, int lim) {
z inv = fpow(lim, mod - 2);
dft(a, lim), reverse(&a[1], &a[lim]);
for (int i = 0; i < lim; i++) a[i] = a[i] * inv;
}
vec operator+(vec a, const vec &b) {
a.resize(max(a.size(), b.size()));
for (int i = 0; i < b.size(); i++) a[i] = a[i] + b[i];
return a;
}
vec operator-(vec a, const vec &b) {
a.resize(max(a.size(), b.size()));
for (int i = 0; i < b.size(); i++) a[i] = a[i] - b[i];
return a;
}
vec operator*(vec a, vec b) {
if (a.size() < 20 || b.size() < 20 || (uint64_t)a.size() + b.size() < 400) {
vec c(a.size() + b.size() - 1);
for (int i = 0; i < a.size(); i++)
for (int j = 0; j < b.size(); j++) c[i + j] = c[i + j] + a[i] * b[j];
return c;
}
int len = a.size() + b.size() - 1, lim = init(len);
dft(a, lim), dft(b, lim);
for (int i = 0; i < lim; i++) a[i] = a[i] * b[i];
return idft(a, lim), a.resize(len), a;
}
vec inve(const vec &f, int len = -1) {
if ((len = ~len ? len : f.size()) == 1) return vec{fpow(f[0], mod - 2)};
vec a = inve(f, (len + 1) >> 1), b(&f[0], &f[len]);
int lim = init((len << 1) - 1);
dft(a, lim), dft(b, lim);
for (int i = 0; i < lim; i++) a[i] = a[i] * (2 - a[i] * b[i]);
return idft(a, lim), a.resize(len), a;
}
vec inte(vec a) {
for (int i = a.size() - 1; i; i--) a[i] = a[i - 1] * ::inv[i];
return *a.begin() = 0, a;
}
vec deri(vec a) {
for (int i = 0; i < a.size() - 1; i++) a[i] = a[i + 1] * (i + 1);
return *a.rbegin() = 0, a;
}
vec ln(const vec &f) { return inte((deri(f) * inve(f)).setl(f.size())); }
vec exp(const vec &f, int len = -1) {
if ((len = ~len ? len : f.size()) == 1) return vec{1};
vec a = exp(f, (len + 1) >> 1), b = a;
b.resize(len), b = ln(b);
for (int i = 0; i < len; i++) b[i] = 0 - b[i];
b[0] = b[0] + 1;
for (int i = 0; i < len; i++) b[i] = b[i] + f[i];
return (a * b).setl(len);
}
vec pow(vec a, int b) {
a = ln(a);
for (int i = 0; i < a.size(); i++) a[i] = a[i] * b;
return exp(a);
}
int main() {
#ifdef memset0
freopen("1.in", "r", stdin);
#endif
cin.tie(0)->sync_with_stdio(0);
fac[0] = ifac[0] = inv[0] = inv[1] = 1;
for (int i = 2; i < N; i++) inv[i] = (mod - mod / i) * inv[mod % i];
for (int i = 1; i < N; i++) fac[i] = fac[i - 1] * i, ifac[i] = ifac[i - 1] * inv[i];
int n, len;
cin >> n, len = n + 5;
vec Finv = inve(vec(ifac + 1, ifac + len + 1).fun1());
for (int i = 0; i <= n; i++) h[i] = Finv[i + 1];
h[0] = h[0] - 1;
vec P(len);
for (int i = 0; i < len; i++) P[i] = (i & 1 ? mod - 1 : 1) * inv[i + 1];
P = inve(P);
vec Ppow = pow(P, n + 1);
vec Pinv = inve(P.fun1());
vec lpart = (Ppow * Pinv).setl(len);
vec rpart = ((deri(P) * Pinv).setl(len) * lpart).setl(len);
z tmp = fpow(n + 1, mod - 2);
for (int i = 0; i <= n; i++) h[i] = h[i] - tmp * (lpart[i + 1] * (n - i + 1) + rpart[i + 1]);
vec f(n + 1), g(n + 1);
for (int i = 0; i <= n; i++) f[i] = ((n - i) & 1 ? mod - 1 : 1) * ifac[n - i];
for (int i = 0; i <= n; i++) g[i] = fac[i] * h[i];
f = f * g;
for (int i = 0; i < n; i++) ans[i] = f[i + n] * fac[n] * ifac[i];
for (int i = 0; i < n; i++) cout << ans[i].x << " \n"[i + 1 == n];
}