「UOJ500」任�意基DFT
「UOJ500」任意基DFT

2020 年 03 月 27 日

给定 nn 次多项式

f(x)=i=0naixif(x) = \sum_{i=0}^n a_i x^i

QQ 次询问,第 ii 次询问 f(qi)f(q_i)998244353998244353 取模的值。

其中 qiq_i 是一个一阶线性递推,给定 q0,x,yq_0, x, y ,满足

qn=xqn1+yq_n = x q_{n-1} + y

1n2.5×105, 1Q106, 2x<998244353, 0q0,y<9982443531 \leq n \leq 2.5 \times 10^5, \ 1 \leq Q \leq 10^6, \ 2 \leq x < 998244353, \ 0 \leq q_0, y < 998244353

Solution Part1

首先这玩意儿肯定没法两个 log\log 多点求值,除非你是钱哥哥。

Solution Part2

虽然这东西好像小学生都会,但我自己推的时候怎么莫名搞到特征多项式那个方向去了,越学越傻逼了。。。

{gn+1=xgn+ygn=xgn1+y\left\{ \begin{aligned} g_{n+1} = x g_n + y \\ g_n = x g_{n-1} + y \\ \end{aligned} \right.

两式相减得到

(gn+1gn)=x(gngn1)(g_{n+1}-g_n) = x (g_n - g_{n-1})

不妨设

hn=gn+1gnh_n = g_{n+1} - g_n

可以得到 hh 的通项公式

hn=xnh0=xn((x1)g0+y)h_n = x^n h_0 = x^n ((x-1) g_0 + y)

从而得到 gg 的通项公式

gn=g0+i=0n1hi=g0+xn1x1((x1)g0+y)g_n = g_0 + \sum_{i=0}^{n-1} h_i = g_0 + \frac {x^n-1} {x-1} ((x-1)g_0 + y)

不妨表示成 axn+bax^n + b 的形式,其中

{a=(x1)g0+yx1b=yx1\begin{cases} a = \dfrac {(x-1) g_0 + y} {x-1} \\ b = \dfrac {y} {x-1} \\ \end{cases}

Solution Part3

也就是说我们要对于所有 i[1,Q]i \in [1, Q] 求出 f(axi+b)f(a x^i + b) 的值,不妨设多项式 g(x)=f(ax+b)g(x) = f(ax + b) ,也就是一个多项式平移。

g(x)=f(ax+b)=i=0nai(ax+b)i=i=0naij=0ii!j!(ij)!(ax)jbij=i=0naii!j=0i(ax)jj!bij(ij)!\begin{aligned} g(x) &= f(ax + b) \\ &= \sum_{i=0}^n a_i (ax + b)^i \\ &= \sum_{i=0}^n a_i \sum_{j=0}^i \frac {i!} {j! (i-j)!} (ax)^j b^{i-j} \\ &= \sum_{i=0}^n a_i i! \sum_{j=0}^i \frac {(ax)^j} {j!} \cdot \frac {b^{i-j}} {(i-j)!} \\ \end{aligned}

容易发现可以卷积维护,需要反转一个多项式。

Solution Part4

回到之前的式子,现在我们只需要对于所以 i[1,Q]i \in [1, Q] 求出 g(xi)g(x^i) 的值,相当于求

G(z)=i=0Qzij=0najxij=i=0Qzij=0najx(i+j2)(i2)(j2)=i=0Qzix(i2)j=0najx(j2)x(i+j2)\begin{aligned} G(z) &= \sum_{i=0}^Q z^i \sum_{j=0}^n a_j x^{ij} \\ &= \sum_{i=0}^Q z^i \sum_{j=0}^n a_j x^{\binom {i+j} 2 - \binom i 2 - \binom j 2} \\ &= \sum_{i=0}^Q z^i x^{-\binom i2} \sum_{j=0}^n a_j x^{-\binom j2} \cdot x^{\binom {i+j} 2} \\ \end{aligned}

容易发现可以卷积维护,需要反转一个多项式。

题外话:

一开始想过把 ijij 拆成 (i2)+(j2)(ij2)\binom i 2 + \binom j 2 - \binom {i-j} 2 ,结果 iji-j 的值域有正有负,反而难处理,尽量还是避免这种情况。

Code

#include <bits/stdc++.h>
using namespace std;
const int N = 2.1e6 + 10, mod = 998244353;
int n, q, rev[N];
struct z {
  int x;
  z(int x = 0) : x(x) {}
  friend inline z operator*(z a, z b) { return (long long)a.x * b.x % mod; }
  friend inline z operator-(z a, z b) { return (a.x -= b.x) < 0 ? a.x + mod : a.x; }
  friend inline z operator+(z a, z b) { return (a.x += b.x) >= mod ? a.x - mod : a.x; }
} q0, qx, qy, qc, w[N], f[N], g[N], t1[N], t2[N], t3[N], t4[N], fac[N], ifac[N];
inline z fpow(z a, int b) {
  z s = 1;
  for (; b; b >>= 1, a = a * a)
    if (b & 1) s = s * a;
  return s;
}
void initFac(int n) {
  fac[0] = ifac[0] = ifac[1] = 1;
  for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i;
  for (int i = 2; i <= n; i++) ifac[i] = (mod - mod / i) * ifac[mod % i];
  for (int i = 2; i <= n; i++) ifac[i] = ifac[i - 1] * ifac[i];
}
void initPow(z c, int n, z *a, z *b) {
  a[0] = b[0] = 1;
  for (int i = 1; i < n; i++) a[i] = a[i - 1] * c;
  for (int i = 1; i < n; i++) b[i] = b[i - 1] * a[i - 1];
}
int init(int n) {
  static int wk = 1;
  int k = -1, l = 1;
  while (l < n) l <<= 1, ++k;
  for (int i = 0; i < l; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << k);
  for (int &k = wk; k < l; k <<= 1) {
    z c = fpow(3, (mod - 1) / k / 2);
    w[k] = 1;
    for (int i = 1; i < k; i++) w[i + k] = w[i + k - 1] * c;
  }
  return l;
}
void ntt(z *a, z *b, int l) {
  for (int i = 0; i < l; i++) b[rev[i]] = a[i];
  for (int k = 1; k < l; k <<= 1)
    for (int i = 0; i < l; i += (k << 1))
      for (int j = 0; j < k; j++) {
        z x = b[i + j], y = b[i + j + k] * w[j + k];
        b[i + j] = x + y, b[i + j + k] = x - y;
      }
}
void mul(z *a, z *b, z *c, int n, int m) {
  static z t[N];
  int l = init(n + m - 1);
  z v = fpow(l, mod - 2);
  ntt(a, c, l), ntt(b, t, l);
  for (int i = 0; i < l; i++) t[i] = t[i] * c[i];
  reverse(t + 1, t + l), ntt(t, c, l);
  for (int i = 0; i < l; i++) c[i] = c[i] * v;
}
int main() {
#ifdef memset0
  freopen("1.in", "r", stdin);
#endif
  cin.tie(0)->sync_with_stdio(0);
  cin >> n >> q;
  initFac(max(n, q));
  for (int i = 0; i <= n; i++) cin >> f[i].x;
  cin >> q0.x >> qx.x >> qy.x;
  qc = (q0 * (qx - 1) + qy) * fpow(qx - 1, mod - 2);
  const auto move = [&](z *a, z c, z k, int n) {
    int i;
    z t;
    for (i = 0, t = 1; i < n; i++) t1[i] = a[i] * fac[i], t2[n - i - 1] = ifac[i] * t, t = t * c;
    mul(t1, t2, t3, n, n);
    for (i = 0, t = 1; i < n; i++) a[i] = t3[n - 1 + i] * ifac[i] * t, t = t * k;
  };
  memset(t1, 0, sizeof(t1)), memset(t2, 0, sizeof(t2));
  move(f, q0 - qc, qc, n + 1);
  memset(t1, 0, sizeof(t1)), memset(t2, 0, sizeof(t2));
  const auto czt = [&](z *a, z *b, z c, int n, int q) {
    initPow(c, n + q, t3, t1);
    reverse(t1, t1 + n + q);
    initPow(fpow(c, mod - 2), max(n, q), t3, t4);
    for (int i = 0; i < n; i++) t2[i] = a[i] * t4[i];
    mul(t1, t2, t3, n + q, n);
    for (int i = 0; i < q; i++) b[i] = t3[n + q - 1 - i] * t4[i];
  };
  czt(f, g, qx, n + 1, q + 1);
  cout << accumulate(g + 1, g + q + 1, 0, [](int a, z b) { return a ^ b.x; }) << endl;
}

评论

TABLE OF CONTENTS

Solution Part1
Solution Part2
Solution Part3
Solution Part4
Code