FFT 可以用来计算多项式乘法,但复数的运算会产生浮点误差。对于只有整数参与的多项式运算,有时,使用数论变换(Number-Theoretic Transform)会是更好的选择。
原根
FFT 中,我们使用单位复根 。我们需要单位复根的以下性质。
- 互不相同,保证点值表示的合法;
- ,用于分治;
- ,用于分治;
- 当 时,,用于逆变换。
在数论中,考虑一个质数 (其中 为 的幂)。定义其原根 为使得 互不相同的数。
性质一
令 ,由于 互不相同,满足性质一。
性质二
由 可知 (),故 ,满足性质二。
性质三
根据费马小定理得
又因为 ,所以 ,而根据性质一可得 ,即 。可推出 ,满足性质三。
性质四
当 时
由性质三中的推论可知,,故 ,性质四成立。
求原根
求一个质数的原根,可以使用枚举法 —— 枚举 ,检验 是否为 的原根。
对于一个数 ,最小的满足 的正整数 一定是 的约数。
证明:假设最小的 不是 的约数,找到 满足 ,由费马小定理可知
,与假设矛盾。
检验时,只需要枚举 的所有约数 ,检验 即可。
inline long long pow(const long long x, const long long n, const long long p) {
long long ans = 1;
for (long long num = x, tmp = n; tmp; tmp >>= 1, num = num * num % p) if (tmp & 1) ans = ans * num % p;
return ans;
}
inline long long root(const long long p) {
for (int i = 2; i <= p; i++) {
int x = p - 1;
bool flag = false;
for (int k = 2; k * k <= p - 1; k++) if (x % k == 0) {
if (pow(i, (p - 1) / k, p) == 1) {
flag = true;
break;
}
while (x % k == 0) x /= k;
}
if (!flag && (x == 1 || pow(i, (p - 1) / x, p) != 1)) {
return i;
}
}
throw;
}
模板
把原有的复数运算改为模意义下的运算即可。
注意 要改为 。
inline long long pow(const long long x, const long long n, const long long p) {
long long ans = 1;
for (long long num = x, tmp = n; tmp; tmp >>= 1, num = num * num % p) if (tmp & 1) ans = ans * num % p;
return ans;
}
inline long long root(const long long p) {
for (int i = 2; i <= p; i++) {
int x = p - 1;
bool flag = false;
for (int k = 2; k * k <= p - 1; k++) if (x % k == 0) {
if (pow(i, (p - 1) / k, p) == 1) {
flag = true;
break;
}
while (x % k == 0) x /= k;
}
if (!flag && (x == 1 || pow(i, (p - 1) / x, p) != 1)) {
return i;
}
}
throw;
}
inline void exgcd(const long long a, const long long b, long long &g, long long &x, long long &y) {
if (!b) g = a, x = 1, y = 0;
else exgcd(b, a % b, g, y, x), y -= x * (a / b);
}
inline long long inv(const long long a, const long long p) {
long long g, x, y;
exgcd(a, p, g, x, y);
return (x + p) % p;
}
struct NumberTheoreticTransform {
long long omega[MAXM_EXTENDED], omegaInverse[MAXM_EXTENDED];
void init(const int n) {
long long g = root(MOD), x = pow(g, (MOD - 1) / n, MOD);
for (int i = 0; i < n; i++) {
assert(i < MAXM_EXTENDED);
omega[i] = (i == 0) ? 1 : omega[i - 1] * x % MOD;
omegaInverse[i] = inv(omega[i], MOD);
}
}
void transform(long long *a, const int n, const long long *omega) {
int k = 0;
while ((1 << k) != n) k++;
for (int i = 0; i < n; i++) {
int t = 0;
for (int j = 0; j < k; j++) if (i & (1 << j)) t |= (1 << (k - j - 1));
if (t > i) std::swap(a[i], a[t]);
}
for (int l = 2; l <= n; l *= 2) {
int m = l / 2;
for (long long *p = a; p != a + n; p += l) {
for (int i = 0; i < m; i++) {
long long t = omega[n / l * i] * p[i + m] % MOD;
p[i + m] = (p[i] - t + MOD) % MOD;
(p[i] += t) %= MOD;
}
}
}
}
void dft(long long *a, const int n) {
transform(a, n, omega);
}
void idft(long long *a, const int n) {
transform(a, n, omegaInverse);
long long x = inv(n, MOD);
for (int i = 0; i < n; i++) a[i] = a[i] * x % MOD;
}
} ntt;