多项式求逆

该部分见 多项式求逆

多项式除法|取模

有 $F(x) = G(x)Q(x) + R(x)$,给定 $F(x)$ 与 $G(x)$,求解 $Q(x), R(x)$,其中 $F(x), G(x)$ 最高次数分别为 $n, m$

显然地(设 $n \ge m$),$Q(x)$ 的次数不超过 $n - m$,$R(x)$ 的次数不超过 $m - 1$(次数不足时视高位系数为 $0$)。考虑对式子进行变形:
$$
\begin{aligned}
&F(x) = G(x)Q(x) + R(x) \\
\Rightarrow &F(\frac1x) = G(\frac1x)Q(\frac1x) + R(\frac1x) \\
\Rightarrow &x^nF(\frac1x) = x^mG(\frac1x)x^{n - m}Q(\frac1x) + x^{n - m + 1}x^{m - 1}R(\frac1x) \\
\end{aligned}
$$
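令 $A_R(x)$ 表示将 $A(x)$ 的系数翻转后得到的多项式:若 $A(x)$ 按次数 $n$ 计,系数为 $(a_0, a_1, \dots, a_n)$,则 $A_R(x) = x^nA(\frac1x)$,系数为 $(a_n, a_{n - 1}, \dots, a_0)$。其中 $R(x)$ 按次数 $m - 1$ 翻转(高位不足补 $0$),即 $R_R(x) = x^{m - 1}R(\frac1x)$。于是有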
$$
\begin{aligned}
\Rightarrow &x^nF(\frac1x) = x^mG(\frac1x)x^{n - m}Q(\frac1x) + x^{n - m + 1}x^{m - 1}R(\frac1x) \\
\Rightarrow &F_R(x) = G_R(x)Q_R(x) + x^{n - m + 1}R_R(x) \\
\Rightarrow &F_R(x) \equiv G_R(x)Q_R(x) \pmod{x^{n - m + 1}} \\
\Rightarrow &Q_R(x) \equiv \frac{F_R(x)}{G_R(x)} \pmod{x^{n - m + 1}}
\end{aligned}
$$
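其中 $G_R(0)$ 等于 $G(x)$ 的最高次项系数,非零,因此 $G_R(x)$ 在模 $x^{n - m + 1}$ 意义下可逆(用多项式求逆求出)。又因为 $Q_R(x)$ 的次数不超过 $n - m$,上式右侧模 $x^{n - m + 1}$ 的结果恰好就是 $Q_R(x)$ 本身,将其系数翻转即得 $Q(x)$。之后则有 $R(x) = F(x) - G(x)Q(x)$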

时间复杂度 $O (n \log n)$

参考代码:
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>

// Coefficient type used throughout.
typedef long long LL;

// Modulus and primitive root for the number-theoretic transform.
// 998244353 = 119 * 2^23 + 1 is prime, so transform lengths up to 2^23 are
// supported, and 3 is a primitive root modulo it. These are typed constants
// rather than macros: a macro named `g` would silently rewrite every later
// token `g` (a local, a parameter, a member) into the literal 3.
const LL MOD = 998244353;
const LL g = 3;

using namespace std;

// Capacity of every coefficient / scratch array; all transform lengths must fit.
const int MAXN = 1 << 20;

// Fast modular exponentiation: returns x^p mod MOD as a value in [0, MOD).
// The base is first reduced into [0, MOD), so a negative base cannot yield a
// negative result and a base of any magnitude cannot overflow x * x.
// p must be non-negative (the loop only terminates once p reaches 0).
inline LL power (LL x, int p) {
    x %= MOD;
    if (x < 0) x += MOD;
    LL result = 1;
    for (; p; p >>= 1) {
        if (p & 1) result = result * x % MOD;
        x = x * x % MOD;
    }
    return result;
}
// g^(MOD-2) == g^(-1) (mod MOD) by Fermat's little theorem, MOD being prime;
// this is the root used by the inverse transform.
const LL invg = power (g, MOD - 2);

// Bit-reversal permutation table for the current transform length; callers
// fill it before calling NTT().
int oppo[MAXN] = {0};
// Current transform length (a power of two), set by the callers.
int limit;

// Iterative in-place number-theoretic transform of a[0 .. limit-1].
// inv == 1 selects the forward transform (root g). Any other value uses g^{-1};
// that is the inverse transform WITHOUT the final division by `limit`, which
// the callers apply themselves.
void NTT (LL* a, int inv) {
    // Reorder the input into bit-reversed index order.
    for (int i = 0; i < limit; ++i) {
        const int r = oppo[i];
        if (i < r) swap (a[i], a[r]);
    }
    const LL root = (inv == 1) ? g : invg;
    // Butterfly passes, doubling the sub-transform length each time.
    for (int half = 1; half < limit; half *= 2) {
        const int len = half * 2;
        const LL wn = power (root, (MOD - 1) / len);  // principal len-th root of unity
        for (int start = 0; start < limit; start += len) {
            LL* lo = a + start;
            LL* hi = lo + half;
            LL w = 1;
            for (int k = 0; k < half; ++k) {
                const LL u = lo[k];
                const LL v = w * hi[k] % MOD;
                lo[k] = (u + v) % MOD;
                hi[k] = (u - v + MOD) % MOD;
                w = w * wn % MOD;
            }
        }
    }
}
// Scratch buffers shared by mul() and inverse().
LL A[MAXN];
LL B[MAXN];
// The polynomial whose inverse inverse() computes.
LL f[MAXN] = {0};

// In-place polynomial multiplication modulo MOD: X <- X * Y, where X has degree
// fn (coefficients X[0..fn]) and Y has degree fm (coefficients Y[0..fm]).
// Writes X[0..fn+fm], so the caller must provide that much room; Y is only read.
// Sets the globals `limit` and `oppo` and clobbers the scratch buffers A and B.
void mul (LL* X, LL* Y, int fn, int fm) {
    // Smallest power of two strictly greater than the product degree, so the
    // cyclic convolution does not wrap around.
    int n = 1;
    while (n <= fn + fm) n <<= 1;
    limit = n;
    // Bit-reversal table. The top bit is taken as n >> 1 rather than
    // 1 << (log2(n) - 1), which keeps the shift well-defined when n == 1
    // (fn + fm == 0) instead of shifting by -1.
    for (int i = 0; i < limit; i ++) oppo[i] = (oppo[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    for (int i = 0; i < limit; i ++) A[i] = B[i] = 0;
    for (int i = 0; i <= fn; i ++) A[i] = X[i];
    for (int i = 0; i <= fm; i ++) B[i] = Y[i];
    NTT (A, 1), NTT (B, 1);
    for (int i = 0; i < limit; i ++) A[i] = A[i] * B[i] % MOD;
    NTT (A, - 1);
    // NTT(.., -1) omits the 1/n scaling of the inverse transform.
    LL invn = power (n, MOD - 2);
    for (int i = 0; i <= fn + fm; i ++) X[i] = A[i] * invn % MOD;
}
void inverse (int deg, LL* a) {
if (deg == 1) { a[0] = power (f[0], MOD - 2); return ; }
inverse ((deg + 1) >> 1, a);
int n, lim;
for (n = 1, lim = 0; n <= (deg << 1); n <<= 1, lim ++);
limit = n;
for (int i = 0; i < limit; i ++) oppo[i] = (oppo[i >> 1] >> 1) | ((i & 1) << (lim - 1));
for (int i = 0; i < limit; i ++) A[i] = B[i] = 0;
for (int i = 0; i < deg; i ++) A[i] = f[i];
for (int i = 0; i < deg << 1; i ++) B[i] = a[i];
NTT (A, 1), NTT (B, 1);
for (int i = 0; i < limit; i ++) B[i] = B[i] * ((2ll - A[i] * B[i] % MOD + MOD) % MOD) % MOD;
NTT (B, - 1);
LL invn = power (n, MOD - 2);
for (int i = 0; i < deg; i ++) a[i] = B[i] * invn % MOD;
}

int N, M;          // degrees of the dividend F and of the divisor G, as read from input
LL F[MAXN];        // dividend coefficients F[0..N]
LL G[MAXN];        // divisor coefficients G[0..M]
LL INVG[MAXN];     // filled by inverse(): inverse of the reversed G modulo x^(N-M+1)
LL RF[MAXN];       // reversed F, then the work buffer holding Q and later R

// Reads the next (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped. A '-' counts as the sign
// only when it is the character immediately preceding the digits.
// Returns 0 if end of input is reached before any digit, instead of looping forever.
inline int getnum () {
    // Keep getchar()'s result as int: it is then either EOF or an unsigned-char
    // value, which is exactly the domain isdigit() accepts.
    int ch = getchar ();
    bool isneg = false;
    while (ch != EOF && ! isdigit (ch)) {
        isneg = (ch == '-');
        ch = getchar ();
    }
    int num = 0;
    while (ch != EOF && isdigit (ch)) {
        num = num * 10 + (ch - '0');
        ch = getchar ();
    }
    return isneg ? - num : num;
}

int main () {
N = getnum (), M = getnum ();
for (int i = 0; i <= N; i ++) RF[i] = F[i] = getnum ();
for (int i = 0; i <= M; i ++) f[i] = G[i] = getnum ();
reverse (RF, RF + N + 1); reverse (f, f + M + 1);
for (int i = N - M + 2; i <= M; i ++) f[i] = 0;
inverse (N - M + 1, INVG); mul (RF, INVG, N, N - M + 1);
reverse (RF, RF + N - M + 1);
for (int i = N - M + 1; i <= N; i ++) RF[i] = 0;
for (int i = 0; i <= N - M; i ++) {
if (i > 0) putchar (' ');
printf ("%lld", RF[i]);
}
puts ("");
mul (RF, G, N - M, M);
for (int i = 0; i < M; i ++) RF[i] = (F[i] - RF[i] + MOD) % MOD;
for (int i = 0; i < M; i ++) {
if (i > 0) putchar (' ');
printf ("%lld", RF[i]);
}
puts ("");

return 0;
}