「是男人就过8题——Pony.ai」IntervalTree
「是男人就过8题——Pony.ai」IntervalTree

2020 年 06 月 02 日

定义区间树为线段树的拓展,即每次断开的位置可以不是线段的中心。

给定一个 [1,n][1, n] 的区间树和 qq 次询问,每次询问包含一个正整数 kk, 你需要求出有多少区间的时间复杂度恰好等于 kk

n,q105, k109n, q\le 10^5,\ k\le 10^9

题解

在线回答询问无意义,考虑利用生成函数处理出所有询问的答案。

询问 [l;r][l;r] 选中的线段(ql=lqr=rql=l \land qr=r 的线段,而非经过的线段),LCA 往两侧深度单调减(且中间平的一段的长度至多为 22)。

求出往两侧单调的生成函数合并,类似:

Lu(x)=x(1+Ll(x)+x(Lr(x)x))=xLl(x)+x2Lr(x)+xx3Ru(x)=x(1+Rr(x)+x(Rl(x)x))=xRr(x)+x2Rl(x)+xx3L_u(x) = x (1 + L_l(x) + x (L_r(x) - x)) = x L_l(x) + x^2 L_r(x) + x - x^3 \\ R_u(x) = x (1 + R_r(x) + x (R_l(x) - x)) = x R_r(x) + x^2 R_l(x) + x - x^3 \\

其中 L,RL,R 分别表示左/右端点和当前线段的左/右端点相同的线段(不包括完全相同的情况)的生成函数,uu 是当前节点,ll 是左儿子,rr 是右儿子。

处理一些平凡情况,在断点计算贡献:

Su(x)=xdepu(Rl(x)Lr(x)x2+1)S_u(x) = x^{dep_u} (R_l(x) L_r(x) - x^2 + 1)

就能做到 polylog×线段长度\text{polylog} \times \sum \small{\text{线段长度}} 复杂度。

进一步优化复杂度,考虑边分治:假设当前处理子树 uu,边分的子树 vv。递归处理出 uvu \leftrightarrow v 的路径上的 S,L,RS,L,R。下面考虑 Lv,RvL_v,R_v 对路径上点的 SSuuL,RL,R 的贡献。

前者可以分别考虑 Lv,RvL_v,R_v 的贡献,通过两次卷积得到。后者则是路径上的 L,RL,R 通过一定位移得到。

至于处理可以分别在两侧继续边分。复杂度 O(nlog2n)O(n \log^2 n)

代码

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10, mod = 998244353;
int _, n, m, cnt, siz[N], vis[N], mid[N], ch[N][2], fa[N], l[N], r[N], dep[N], rev[N << 2];
struct z {
  int x;
  z(int x = 0) : x(x) {}
  friend inline z operator*(z a, z b) { return (long long)a.x * b.x % mod; }
  friend inline z operator-(z a, z b) { return (a.x -= b.x) < 0 ? a.x + mod : a.x; }
  friend inline z operator+(z a, z b) { return (a.x += b.x) >= mod ? a.x - mod : a.x; }
} w[N << 2];
vector<z> ans;
inline z fpow(z a, int b) {
  z s = 1;
  for (; b; b >>= 1, a = a * a)
    if (b & 1) s = s * a;
  return s;
}
void dft(vector<z> &a, int lim) {
  a.resize(lim);
  for (int i = 0; i < lim; i++)
    if (i < rev[i]) swap(a[i], a[rev[i]]);
  for (int len = 1; len < lim; len <<= 1)
    for (int i = 0; i < lim; i += (len << 1))
      for (int j = 0; j < len; j++) {
        z x = a[i + j], y = a[i + j + len] * w[j + len];
        a[i + j] = x + y, a[i + j + len] = x - y;
      }
}
vector<z> operator+(vector<z> a, const vector<z> &b) {
  a.resize(max(a.size(), b.size()));
  for (int i = 0; i < b.size(); i++) a[i] = a[i] + b[i];
  return a;
}
vector<z> operator-(vector<z> a, const vector<z> &b) {
  a.resize(max(a.size(), b.size()));
  for (int i = 0; i < b.size(); i++) a[i] = a[i] - b[i];
  return a;
}
vector<z> operator*(vector<z> a, vector<z> b) {
  int len = a.size() + b.size() - 1, lim = 1, k = 0;
  while (lim < len) lim <<= 1, ++k;
  for (int i = 0; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (k - 1));
  dft(a, lim), dft(b, lim);
  for (int i = 0; i < lim; i++) a[i] = a[i] * b[i];
  dft(a, lim), reverse(&a[1], &a[lim]);
  z inv = fpow(lim, mod - 2);
  for (int i = 0; i < lim; i++) a[i] = a[i] * inv;
  return a.resize(len), a;
}
void shift(z x, vector<z> &dst, size_t dta) {
  dst.resize(max(dst.size(), dta + 1));
  dst[dta] = dst[dta] + x;
}
void shift(const vector<z> &src, vector<z> &dst, size_t dta) {
  dst.resize(max(dst.size(), src.size() + dta));
  for (int i = 0; i < src.size(); i++) dst[i + dta] = dst[i + dta] + src[i];
}
void dfsInit(int &u, int l, int r, int dep) {
  if (l == r) {
    u = n - 1 + l;
  } else {
    u = ++cnt;
    dfsInit(ch[u][0], l, mid[u], dep + 1), fa[ch[u][0]] = u;
    dfsInit(ch[u][1], mid[u] + 1, r, dep + 1), fa[ch[u][1]] = u;
  }
  ::l[u] = l, ::r[u] = r, ::dep[u] = dep;
}
int calcSize(int u) {
  if (!u || vis[u]) return 0;
  return siz[u] = 1 + calcSize(ch[u][0]) + calcSize(ch[u][1]);
}
pair<int, int> findSubTree(int u, int lim) {
  if (!u || vis[u]) return {-1, -1};
  if (siz[u] < lim) return {u, siz[u]};
  pair<int, int> x = ch[u][0] ? findSubTree(ch[u][0], lim) : make_pair(-1, -1);
  pair<int, int> y = ch[u][1] ? findSubTree(ch[u][1], lim) : make_pair(-1, -1);
  return x.second > y.second ? x : y;
}
void calc(bool fl, int u, int mov, vector<z> &f) {
  if (vis[u] || l[u] == r[u]) return shift(1, f, mov + 1);
  shift(1, f, mov + 1);
  shift(mod - 1, f, mov + 3);
  calc(fl, ch[u][0], mov + (fl ? 2 : 1), f);
  calc(fl, ch[u][1], mov + (fl ? 1 : 2), f);
}
pair<vector<z>, vector<z>> fuck(int u) {
  if (vis[u] || l[u] == r[u]) return {{0, 1}, {0, 1}};
  vector<z> Ll, Lr, Rl, Rr, Lu, Ru;
  Lu = Ru = vector<z>{0, 1, 0, mod - 1};
  tie(Ll, Rl) = fuck(ch[u][0]);
  tie(Lr, Rr) = fuck(ch[u][1]);
  shift(Ll, Lu, 1), shift(Lr, Lu, 2);
  shift(Rr, Ru, 1), shift(Rl, Ru, 2);
  shift(Rl * Lr, ans, dep[u]);
  return {Lu, Ru};
}
pair<vector<z>, vector<z>> solve(int u) {
  if (vis[u] || l[u] == r[u]) return {{0, 1}, {0, 1}};
  int siz = calcSize(u);
  int v = findSubTree(u, (siz * 2) / 4).first;
  if (v == -1) return fuck(u);
  vector<z> Lu, Ru, Lv, Rv, Lt, Rt, T;
  tie(Lv, Rv) = solve(v);
  vis[v] = 1;
  tie(Lu, Ru) = solve(u);
  vis[v] = 0;
  int lmov = 0, rmov = 0;
  for (int p = v; p != u; lmov += ch[fa[p]][0] == p ? 1 : 2, rmov += ch[fa[p]][0] == p ? 2 : 1, p = fa[p]) {
    int f = fa[p], q = ch[f][0] == p ? ch[f][1] : ch[f][0];
    if (ch[f][0] == p) {
      T.clear(), calc(0, q, 0, T), shift(T, Lt, rmov + dep[f] - dep[u]);
    } else {
      T.clear(), calc(1, q, 0, T), shift(T, Rt, lmov + dep[f] - dep[u]);
    }
  }
  shift(mod - 1, Lv, 1);
  shift(mod - 1, Rv, 1);
  shift(Lv * Rt, ans, dep[u]);
  shift(Rv * Lt, ans, dep[u]);
  shift(Lv, Lu, lmov);
  shift(Rv, Ru, rmov);
  return {Lu, Ru};
}
void solution() {
  for (int i = 1; i < n; i++) cin >> mid[i];
  dfsInit(_, 1, n, 1);
  solve(1);
  for (int i = 1; i < (n << 1); i++) shift(1, ans, dep[i]);
  for (int i = 1; i < n; i++) shift(mod - 1, ans, dep[i] + 2);
  for (int i = 1, q; i <= m; i++) {
    cin >> q;
    cout << (q < ans.size() ? ans[q].x : 0) << '\n';
  }
}
void recycle() {
  cnt = 0;
  ans.clear();
  memset(ch, 0, sizeof(ch));
  memset(fa, 0, sizeof(fa));
}
int main() {
  cin.tie(0)->sync_with_stdio(0);
  for (int len = 1; len < (N << 1); len <<= 1) {
    z wn = fpow(3, (mod - 1) / (len << 1));
    w[len] = 1;
    for (int i = 1; i < len; i++) w[i + len] = w[i + len - 1] * wn;
  }
  while (cin >> n >> m) solution(), recycle();
}

评论

TABLE OF CONTENTS

题解
代码