有 张带编号卡牌,每次你可以随机抽取一张。抽中每张的概率均为 。当编号连续的 张牌都被抽取过时,游戏结束。
问游戏结束的期望步数。
。
题解
Part 1
我们可以直接对每张牌第一次被抽中的操作序列计数。
把牌的每一段编号连续区间分开考虑,每一段处理出选中连续区间长度不超过 的方案数(同时容易得到超过的方案数),然后分治 + NTT 合并,这是平凡的。
这个做法的时间复杂度是 的,瓶颈在于前半部分即处理出分成把 个 段满足每一段长度都不超过 的方案数,更进一步的可以表示为:
我们需要求出多项式 。
Part2
注意到这是一个拓展拉格朗日反演的形式,我们需要求出 的复合逆。 相当于我们要求 满足 ,根据多项式牛顿迭代,有
由多项式牛顿迭代,我们可以倍增得到 。
Part3
代入拓展拉格朗日反演的式子,令 我们可以得到
设 ,则有
即可直接得到 。
问题解决,总时间复杂度 。
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10, mod = 998244353;
int n, m, k, a[N], b[N], rev[N << 2];
struct z {
int x;
inline z() : x(0) {}
inline z(int x) : x(x) {}
friend inline z operator*(z a, z b) { return (long long)a.x * b.x % mod; }
friend inline z operator-(z a, z b) { return (a.x -= b.x) < 0 ? a.x + mod : a.x; }
friend inline z operator+(z a, z b) { return (a.x += b.x) >= mod ? a.x - mod : a.x; }
} ans, len[N], fac[N], ifac[N], w[N << 2], S[N];
using vec = vector<z>;
inline z C(int n, int m) { return n < m ? 0 : fac[n] * ifac[m] * ifac[n - m]; }
inline z fpow(z a, int b) {
z s = 1;
for (; b; b >>= 1, a = a * a)
if (b & 1) s = s * a;
return s;
}
inline vec resize(vec a, int n) {
a.resize(n);
return a;
}
void initfac(int n) {
fac[0] = fac[1] = ifac[0] = ifac[1] = 1;
for (int i = 2; i <= n; i++) fac[i] = fac[i - 1] * i;
for (int i = 2; i <= n; i++) ifac[i] = (mod - mod / i) * ifac[mod % i];
for (int i = 2; i <= n; i++) ifac[i] = ifac[i - 1] * ifac[i];
}
int init(int n) {
int lim = 1, k = 0;
while (lim < n) lim <<= 1, ++k;
for (int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
static int len = 1;
for (; len < lim; len <<= 1) {
z wn = fpow(3, (mod - 1) / (len << 1));
w[len] = 1;
for (int i = 1; i < len; i++) w[i + len] = w[i + len - 1] * wn;
}
return lim;
}
void dft(vec &a, int lim) {
a.resize(lim);
for (int i = 0; i < lim; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int len = 1; len < lim; len <<= 1)
for (int i = 0; i < lim; i += (len << 1))
for (int j = 0; j < len; j++) {
z x = a[i + j], y = a[i + j + len] * w[j + len];
a[i + j] = x + y, a[i + j + len] = x - y;
}
}
void idft(vec &a, int lim) {
dft(a, lim), reverse(&a[1], &a[lim]);
z inv = fpow(lim, mod - 2);
for (int i = 0; i < lim; i++) a[i] = a[i] * inv;
}
inline vec mul(vec a, vec b, int l) {
if (a.size() < 10 || b.size() < 10) {
vec c(a.size() + b.size() - 1);
for (int i = 0; i < a.size(); i++)
for (int j = 0; j < b.size(); j++) c[i + j] = c[i + j] + a[i] * b[j];
return c.resize(l), c;
}
int len = a.size() + b.size() - 1, lim = init(len);
dft(a, lim), dft(b, lim);
for (int i = 0; i < lim; i++) a[i] = a[i] * b[i];
return idft(a, lim), a.resize(l), a;
}
inline vec operator*(const vec &a, const vec &b) { return mul(a, b, a.size() + b.size() - 1); }
inline vec operator+(vec a, const vec &b) {
a.resize(max(a.size(), b.size()));
for (int i = 0; i < b.size(); i++) a[i] = a[i] + b[i];
return a;
}
inline vec operator-(vec a, const vec &b) {
a.resize(max(a.size(), b.size()));
for (int i = 0; i < b.size(); i++) a[i] = a[i] - b[i];
return a;
}
vec inv(const vec &f, int len = -1) {
if ((len = ~len ? len : f.size()) == 1) return {fpow(f[0], mod - 2)};
vec a(&f[0], &f[len]), b = inv(f, (len + 1) >> 1);
int lim = init((len << 1) - 1);
dft(a, lim), dft(b, lim);
for (int i = 0; i < lim; i++) a[i] = b[i] * (2 - a[i] * b[i]);
return idft(a, lim), a.resize(len), a;
}
vec deri(vec f) {
for (int i = 0; i <= (int)f.size() - 2; i++) f[i] = f[i + 1] * (i + 1);
return f.back() = 0, f;
}
vec inte(vec f) {
for (int i = (int)f.size() - 1; i >= 1; i--) f[i] = f[i - 1] * fpow(i, mod - 2);
return f.front() = 0, f;
}
vec ln(const vec &f) { return inte(mul(inv(f), deri(f), f.size())); }
vec exp(const vec &f, int len = -1) {
if ((len = ~len ? len : f.size()) == 1) return {1};
vec a(&f[0], &f[len]), b = exp(f, (len + 1) >> 1);
return b.resize(len), mul(b, a + vec{1} - ln(b), len);
}
vec fpow(vec a, int b) {
int n = a.size();
vec s;
for (int c = 0; c < n; c++)
if (a[c].x) {
int l = n - c * b;
if (l <= 0) return s.resize(n), s;
for (int i = 0; i < l; i++) a[i] = a[i + c];
a.resize(l);
a = ln(a);
for (int i = 0; i < l; i++) a[i] = a[i] * b;
a = exp(a), s.resize(c * b);
s.insert(s.end(), a.begin(), a.end());
return s;
}
return a;
}
vec complex(const vec &g) { // F(G(x))
vec s, c = {1};
for (int i = 1; i <= k; i++) c = mul(c, g, g.size()), s = s + c;
return s;
}
vec complex_inv(int len) { // G^{-1}(F(x))
if (len == 1) return {0};
vec g = resize(complex_inv((len + 1) >> 1), len), gk = fpow(g, k), gk1 = mul(gk, g, len);
vec res = g - mul(mul(g - gk1 - vec{0, 1} * (vec{1} - g), vec{1} - g, len), inv(vec{1} - vec{k + 1} * gk + vec{k} * gk1), len);
return res;
}
inline vec sol(int n) { // n+1个球,分m组,每组1~k个。
vec g = complex_inv(n + 1), res(n + 1);
g.erase(g.begin());
g = fpow(inv(g), n + 1) * vec{fpow(n + 1, mod - 2)};
for (int i = 1; i <= n; i++) res[i - 1] = (i + 1) * g[n - i];
reverse(&res[0], &res[n]), res[n] = n + 1 <= k;
return res;
}
pair<vec, vec> solve(int l, int r) {
if (l == r) {
int n = b[l];
vec F(n + 1), G = sol(n);
for (int i = 0; i <= n; i++) {
F[i] = fac[n] * ifac[n - i] - G[i] * fac[i] - (i ? S[i - 1] : 0) * ifac[n - i];
S[i] = (i ? S[i - 1] : 0) + F[i] * fac[n - i];
}
for (int i = 0; i < n; i++) F[i] = F[i + 1] * ifac[i];
return F.pop_back(), pair<vec, vec>{F, G};
}
int mid = (l + r) >> 1;
auto L = solve(l, (l + r) >> 1), R = solve(((l + r) >> 1) + 1, r);
return {L.first * R.second + L.second * R.first, L.second * R.second};
}
int main() {
#ifdef memset0
freopen("1.in", "r", stdin);
#endif
cin >> n >> k;
initfac(n + 5);
for (int i = 1; i <= n; i++) cin >> a[i];
sort(a + 1, a + n + 1);
for (int i = 1, j; i <= n; i = j + 1) {
for (j = i; j < n && a[j + 1] == a[i] + j - i + 1; j++)
;
b[++m] = j - i + 1;
}
auto res = solve(1, m);
for (int i = 1; i <= n; i++) {
len[i] = len[i - 1] + n * fpow(n - i + 1, mod - 2);
ans = ans + res.first[i - 1] * fac[i - 1] * fac[n - i] * len[i];
}
cout << (ans * ifac[n]).x << endl;
}